Laurie committed on
Commit abbcb88
1 Parent(s): aa889d6

Add src folder

src/__init__.py ADDED
Empty file
src/api_demo.py ADDED
@@ -0,0 +1,101 @@
1
+ # coding=utf-8
2
+ # Implements API for fine-tuned models.
3
+ # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
4
+
5
+ # Request:
6
+ # curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": []}'
7
+
8
+ # Response:
9
+ # {
10
+ # "response": "'Hi there!'",
11
+ # "history": "[('Hello there!', 'Hi there!')]",
12
+ # "status": 200,
13
+ # "time": "2000-00-00 00:00:00"
14
+ # }
15
+
16
+
17
+ import json
18
+ import torch
19
+ import uvicorn
20
+ import datetime
21
+ from fastapi import FastAPI, Request
22
+
23
+ from utils import (
24
+ Template,
25
+ load_pretrained,
26
+ prepare_infer_args,
27
+ get_logits_processor
28
+ )
29
+
30
+
31
+ def torch_gc():
32
+ if torch.cuda.is_available():
33
+ num_gpus = torch.cuda.device_count()
34
+ for device_id in range(num_gpus):
35
+ with torch.cuda.device(device_id):
36
+ torch.cuda.empty_cache()
37
+ torch.cuda.ipc_collect()
38
+
39
+
40
+ app = FastAPI()
41
+
42
+
43
+ @app.post("/")
44
+ async def create_item(request: Request):
45
+ global model, tokenizer, prompt_template, generating_args
46
+
47
+ # Parse the request JSON
48
+ json_post_raw = await request.json()
49
+ json_post = json.dumps(json_post_raw)
50
+ json_post_list = json.loads(json_post)
51
+ prompt = json_post_list.get("prompt")
52
+ history = json_post_list.get("history")
53
+ max_new_tokens = json_post_list.get("max_new_tokens", None)
54
+ top_p = json_post_list.get("top_p", None)
55
+ temperature = json_post_list.get("temperature", None)
56
+
57
+ # Tokenize the input prompt
58
+ input_ids = tokenizer([prompt_template.get_prompt(prompt, history)], return_tensors="pt")["input_ids"]
59
+ input_ids = input_ids.to(model.device)
60
+
61
+ # Generation arguments
62
+ gen_kwargs = generating_args.to_dict()
63
+ gen_kwargs["input_ids"] = input_ids
64
+ gen_kwargs["logits_processor"] = get_logits_processor()
65
+ gen_kwargs["max_new_tokens"] = max_new_tokens if max_new_tokens else gen_kwargs["max_new_tokens"]
66
+ gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
67
+ gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]
68
+
69
+ # Generate response
70
+ with torch.no_grad():
71
+ generation_output = model.generate(**gen_kwargs)
72
+ outputs = generation_output.tolist()[0][len(input_ids[0]):]
73
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
74
+
75
+ # Update history
76
+ history = history + [(prompt, response)]
77
+
78
+ # Prepare response
79
+ now = datetime.datetime.now()
80
+ time = now.strftime("%Y-%m-%d %H:%M:%S")
81
+ answer = {
82
+ "response": repr(response),
83
+ "history": repr(history),
84
+ "status": 200,
85
+ "time": time
86
+ }
87
+
88
+ # Log and clean up
89
+ log = "[" + time + "] " + "\", prompt:\"" + prompt + "\", response:\"" + repr(response) + "\""
90
+ print(log)
91
+ torch_gc()
92
+
93
+ return answer
94
+
95
+
96
+ if __name__ == "__main__":
97
+ model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
98
+ model, tokenizer = load_pretrained(model_args, finetuning_args)
99
+ prompt_template = Template(data_args.prompt_template)
100
+
101
+ uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
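For reference, the endpoint above can also be called from Python instead of curl. The snippet below is a minimal client sketch and is not part of the commit; it assumes the server is running locally on port 8000 (as configured above) and that the third-party `requests` package is installed.

# Minimal client sketch for api_demo.py (illustrative only, not part of this commit).
import requests

payload = {"prompt": "Hello there!", "history": []}
resp = requests.post("http://127.0.0.1:8000/", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()
print(data["response"])  # model reply (repr-encoded by the server)
print(data["history"])   # updated history to pass with the next request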
src/cli_demo.py ADDED
@@ -0,0 +1,70 @@
1
+ # coding=utf-8
2
+ # Implements stream chat in command line for fine-tuned models.
3
+ # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
4
+
5
+
6
+ from utils import (
7
+ Template,
8
+ load_pretrained,
9
+ prepare_infer_args,
10
+ get_logits_processor
11
+ )
12
+ from threading import Thread
13
+ from transformers import TextIteratorStreamer
14
+
15
+
16
+ def main():
17
+
18
+ model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
19
+ model, tokenizer = load_pretrained(model_args, finetuning_args)
20
+
21
+ model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
22
+ prompt_template = Template(data_args.prompt_template)
23
+
24
+ def predict_and_print(query, history: list) -> list:
25
+ input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
26
+ input_ids = input_ids.to(model.device)
27
+
28
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
29
+
30
+ gen_kwargs = generating_args.to_dict()
31
+ gen_kwargs["input_ids"] = input_ids
32
+ gen_kwargs["logits_processor"] = get_logits_processor()
33
+ gen_kwargs["streamer"] = streamer
34
+
35
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
36
+ thread.start()
37
+
38
+ print("{}: ".format(model_name), end="", flush=True)
39
+ response = ""
40
+ for new_text in streamer:
41
+ print(new_text, end="", flush=True)
42
+ response += new_text
43
+ print()
44
+ history = history + [(query, response)]
45
+ return history
46
+
47
+ history = []
48
+ print("欢迎使用 {} 模型,输入内容即可对话,clear清空对话历史,stop终止程序".format(model_name))
49
+ while True:
50
+ try:
51
+ query = input("\nInput: ")
52
+ except UnicodeDecodeError:
53
+ print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
54
+ continue
55
+ except Exception:
56
+ raise
57
+
58
+ if query.strip() == "stop":
59
+ break
60
+
61
+ if query.strip() == "clear":
62
+ history = []
63
+ print("History has been removed.")
64
+ continue
65
+
66
+ history = predict_and_print(query, history)
67
+
68
+
69
+ if __name__ == "__main__":
70
+ main()
src/export_model.py ADDED
@@ -0,0 +1,19 @@
1
+ # coding=utf-8
2
+ # Exports the fine-tuned model.
3
+ # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
4
+
5
+
6
+ from utils import load_pretrained, prepare_args
7
+
8
+
9
+ def main():
10
+
11
+ model_args, _, training_args, finetuning_args = prepare_args(stage="sft")
12
+ model, tokenizer = load_pretrained(model_args, finetuning_args)
13
+ model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
14
+ tokenizer.save_pretrained(training_args.output_dir)
15
+ print("model and tokenizer have been saved at:", training_args.output_dir)
16
+
17
+
18
+ if __name__ == "__main__":
19
+ main()
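Once exported, the output directory is a regular Transformers checkpoint and can be loaded back with the standard API. A minimal sketch, where `path_to_save_model` is a placeholder for the `--output_dir` used above (models that ship custom code, e.g. Baichuan, may additionally need `trust_remote_code=True`):

# Loading the exported model for inference (sketch; the path is a placeholder).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path_to_save_model")
model = AutoModelForCausalLM.from_pretrained("path_to_save_model")

inputs = tokenizer("Hello there!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))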
src/train_ppo.py ADDED
@@ -0,0 +1,81 @@
1
+ # coding=utf-8
2
+ # Implements parameter-efficient PPO training of fine-tuned models.
3
+ # This code is inspired by:
4
+ # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
5
+
6
+ import math
7
+
8
+ from torch.optim import AdamW
9
+ from transformers.optimization import get_scheduler
10
+ from trl import PPOConfig
11
+
12
+ from utils import (
13
+ DynamicDataCollatorWithPadding,
14
+ PPOPeftTrainer,
15
+ LogCallback,
16
+ load_pretrained,
17
+ prepare_args,
18
+ prepare_data,
19
+ preprocess_data,
20
+ plot_loss
21
+ )
22
+
23
+
24
+ def main():
25
+
26
+ # Prepare pretrained model and dataset
27
+ model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo")
28
+ dataset = prepare_data(model_args, data_args)
29
+ model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
30
+ dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
31
+ data_collator = DynamicDataCollatorWithPadding(tokenizer)
32
+
33
+ ppo_config = PPOConfig(
34
+ model_name=model_args.model_name_or_path,
35
+ learning_rate=training_args.learning_rate,
36
+ mini_batch_size=training_args.per_device_train_batch_size,
37
+ batch_size=training_args.per_device_train_batch_size,
38
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
39
+ ppo_epochs=1,
40
+ max_grad_norm=training_args.max_grad_norm
41
+ )
42
+
43
+ optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
44
+ total_train_batch_size = \
45
+ training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
46
+ lr_scheduler = get_scheduler(
47
+ training_args.lr_scheduler_type,
48
+ optimizer=optimizer,
49
+ num_warmup_steps=training_args.warmup_steps,
50
+ num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
51
+ )
52
+
53
+ # Initialize our Trainer
54
+ ppo_trainer = PPOPeftTrainer(
55
+ training_args=training_args,
56
+ finetuning_args=finetuning_args,
57
+ callbacks=[LogCallback()],
58
+ config=ppo_config,
59
+ model=model,
60
+ ref_model=None,
61
+ tokenizer=tokenizer,
62
+ dataset=dataset,
63
+ data_collator=data_collator,
64
+ optimizer=optimizer,
65
+ lr_scheduler=lr_scheduler
66
+ )
67
+
68
+ ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
69
+ ppo_trainer.save_model()
70
+ ppo_trainer.save_state() # must be after save_model
71
+ if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
72
+ plot_loss(training_args.output_dir, keys=["loss", "reward"])
73
+
74
+
75
+ def _mp_fn(index):
76
+ # For xla_spawn (TPUs)
77
+ main()
78
+
79
+
80
+ if __name__ == "__main__":
81
+ main()
src/train_pt.py ADDED
@@ -0,0 +1,81 @@
1
+ # coding=utf-8
2
+ # Implements several parameter-efficient pre-training methods.
3
+ # This code is inspired by
4
+ # https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
5
+
6
+
7
+ import math
8
+
9
+ from utils import (
10
+ DynamicDataCollatorWithPadding,
11
+ PeftTrainer,
12
+ LogCallback,
13
+ load_pretrained,
14
+ prepare_args,
15
+ prepare_data,
16
+ preprocess_data,
17
+ plot_loss
18
+ )
19
+
20
+
21
+ def main():
22
+
23
+ # Prepare pretrained model and dataset
24
+ model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
25
+ dataset = prepare_data(model_args, data_args)
26
+ model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
27
+ dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
28
+ data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)
29
+
30
+ # Split the dataset
31
+ if training_args.do_train:
32
+ if data_args.dev_ratio > 1e-6:
33
+ dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
34
+ trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
35
+ else:
36
+ trainer_kwargs = {"train_dataset": dataset}
37
+ else: # do_eval or do_predict
38
+ trainer_kwargs = {"eval_dataset": dataset}
39
+
40
+ # Initialize our Trainer
41
+ trainer = PeftTrainer(
42
+ finetuning_args=finetuning_args,
43
+ model=model,
44
+ args=training_args,
45
+ tokenizer=tokenizer,
46
+ data_collator=data_collator,
47
+ callbacks=[LogCallback()],
48
+ **trainer_kwargs
49
+ )
50
+
51
+ # Training
52
+ if training_args.do_train:
53
+ train_result = trainer.train()
54
+ trainer.log_metrics("train", train_result.metrics)
55
+ trainer.save_metrics("train", train_result.metrics)
56
+ trainer.save_state()
57
+ trainer.save_model()
58
+ if trainer.is_world_process_zero() and model_args.plot_loss:
59
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
60
+
61
+ # Evaluation
62
+ if training_args.do_eval:
63
+ metrics = trainer.evaluate(metric_key_prefix="eval")
64
+
65
+ try:
66
+ perplexity = math.exp(metrics["eval_loss"])
67
+ except OverflowError:
68
+ perplexity = float("inf")
69
+ metrics["perplexity"] = perplexity
70
+
71
+ trainer.log_metrics("eval", metrics)
72
+ trainer.save_metrics("eval", metrics)
73
+
74
+
75
+ def _mp_fn(index):
76
+ # For xla_spawn (TPUs)
77
+ main()
78
+
79
+
80
+ if __name__ == "__main__":
81
+ main()
src/train_rm.py ADDED
@@ -0,0 +1,76 @@
1
+ # coding=utf-8
2
+ # Implements parameter-efficient training of reward models.
3
+ # This code is inspired by:
4
+ # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
5
+ # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
6
+
7
+
8
+ from utils import (
9
+ PairwiseDataCollatorWithPadding,
10
+ PairwisePeftTrainer,
11
+ LogCallback,
12
+ load_pretrained,
13
+ prepare_args,
14
+ prepare_data,
15
+ preprocess_data,
16
+ compute_accuracy,
17
+ plot_loss
18
+ )
19
+
20
+ def main():
21
+
22
+ # Prepare pretrained model and dataset
23
+ model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
24
+ dataset = prepare_data(model_args, data_args)
25
+ model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
26
+ dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
27
+ data_collator = PairwiseDataCollatorWithPadding(tokenizer)
28
+
29
+ training_args.remove_unused_columns = False # important for pairwise dataset
30
+
31
+ # Split the dataset
32
+ if training_args.do_train:
33
+ if data_args.dev_ratio > 1e-6:
34
+ dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
35
+ trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
36
+ else:
37
+ trainer_kwargs = {"train_dataset": dataset}
38
+ else: # do_eval or do_predict
39
+ trainer_kwargs = {"eval_dataset": dataset}
40
+
41
+ # Initialize our Trainer
42
+ trainer = PairwisePeftTrainer(
43
+ finetuning_args=finetuning_args,
44
+ model=model,
45
+ args=training_args,
46
+ tokenizer=tokenizer,
47
+ data_collator=data_collator,
48
+ callbacks=[LogCallback()],
49
+ compute_metrics=compute_accuracy,
50
+ **trainer_kwargs
51
+ )
52
+
53
+ # Training
54
+ if training_args.do_train:
55
+ train_result = trainer.train()
56
+ trainer.log_metrics("train", train_result.metrics)
57
+ trainer.save_metrics("train", train_result.metrics)
58
+ trainer.save_state()
59
+ trainer.save_model()
60
+ if trainer.is_world_process_zero() and model_args.plot_loss:
61
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
62
+
63
+ # Evaluation
64
+ if training_args.do_eval:
65
+ metrics = trainer.evaluate(metric_key_prefix="eval")
66
+ trainer.log_metrics("eval", metrics)
67
+ trainer.save_metrics("eval", metrics)
68
+
69
+
70
+ def _mp_fn(index):
71
+ # For xla_spawn (TPUs)
72
+ main()
73
+
74
+
75
+ if __name__ == "__main__":
76
+ main()
src/train_sft.py ADDED
@@ -0,0 +1,97 @@
1
+ # coding=utf-8
2
+ # Implements several parameter-efficient supervised fine-tuning methods.
3
+ # This code is inspired by
4
+ # https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
5
+
6
+
7
+ from utils import (
8
+ DynamicDataCollatorWithPadding,
9
+ Seq2SeqPeftTrainer,
10
+ ComputeMetrics,
11
+ LogCallback,
12
+ load_pretrained,
13
+ prepare_args,
14
+ prepare_data,
15
+ preprocess_data,
16
+ get_logits_processor,
17
+ plot_loss
18
+ )
19
+
20
+
21
+ def main():
22
+
23
+ # Prepare pretrained model and dataset
24
+ model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
25
+ dataset = prepare_data(model_args, data_args)
26
+ model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
27
+ dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
28
+ data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)
29
+
30
+ # Override the decoding parameters of Seq2SeqTrainer
31
+ training_args.generation_max_length = training_args.generation_max_length if \
32
+ training_args.generation_max_length is not None else data_args.max_target_length
33
+ training_args.generation_num_beams = data_args.eval_num_beams if \
34
+ data_args.eval_num_beams is not None else training_args.generation_num_beams
35
+
36
+ # Split the dataset
37
+ if training_args.do_train:
38
+ if data_args.dev_ratio > 1e-6:
39
+ dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
40
+ trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
41
+ else:
42
+ trainer_kwargs = {"train_dataset": dataset}
43
+ else: # do_eval or do_predict
44
+ trainer_kwargs = {"eval_dataset": dataset}
45
+
46
+ # Initialize our Trainer
47
+ trainer = Seq2SeqPeftTrainer(
48
+ finetuning_args=finetuning_args,
49
+ model=model,
50
+ args=training_args,
51
+ tokenizer=tokenizer,
52
+ data_collator=data_collator,
53
+ callbacks=[LogCallback()],
54
+ compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
55
+ **trainer_kwargs
56
+ )
57
+
58
+ # Keyword arguments for `model.generate`
59
+ gen_kwargs = {
60
+ "do_sample": True,
61
+ "top_p": 0.7,
62
+ "max_new_tokens": data_args.max_target_length + 1,
63
+ "temperature": 0.95,
64
+ "logits_processor": get_logits_processor()
65
+ }
66
+
67
+ # Training
68
+ if training_args.do_train:
69
+ train_result = trainer.train()
70
+ trainer.log_metrics("train", train_result.metrics)
71
+ trainer.save_metrics("train", train_result.metrics)
72
+ trainer.save_state()
73
+ trainer.save_model()
74
+ if trainer.is_world_process_zero() and model_args.plot_loss:
75
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
76
+
77
+ # Evaluation
78
+ if training_args.do_eval:
79
+ metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
80
+ trainer.log_metrics("eval", metrics)
81
+ trainer.save_metrics("eval", metrics)
82
+
83
+ # Predict
84
+ if training_args.do_predict:
85
+ predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
86
+ trainer.log_metrics("predict", predict_results.metrics)
87
+ trainer.save_metrics("predict", predict_results.metrics)
88
+ trainer.save_predictions(predict_results, tokenizer)
89
+
90
+
91
+ def _mp_fn(index):
92
+ # For xla_spawn (TPUs)
93
+ main()
94
+
95
+
96
+ if __name__ == "__main__":
97
+ main()
src/utils/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from .common import (
2
+ load_pretrained,
3
+ prepare_args,
4
+ prepare_infer_args,
5
+ prepare_data,
6
+ preprocess_data
7
+ )
8
+
9
+ from .data_collator import DynamicDataCollatorWithPadding
10
+
11
+ from .peft_trainer import PeftTrainer, LogCallback
12
+
13
+ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
14
+ from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer, compute_accuracy
15
+ from .ppo import PPOPeftTrainer
16
+
17
+ from .template import Template
18
+
19
+ from .other import get_logits_processor, plot_loss
src/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (785 Bytes).
src/utils/__pycache__/common.cpython-310.pyc ADDED
Binary file (16 kB).
src/utils/__pycache__/config.cpython-310.pyc ADDED
Binary file (11.3 kB).
src/utils/__pycache__/data_collator.cpython-310.pyc ADDED
Binary file (2.99 kB).
src/utils/__pycache__/other.cpython-310.pyc ADDED
Binary file (7.37 kB).
src/utils/__pycache__/pairwise.cpython-310.pyc ADDED
Binary file (2.93 kB).
src/utils/__pycache__/peft_trainer.cpython-310.pyc ADDED
Binary file (5.13 kB).
src/utils/__pycache__/ppo.cpython-310.pyc ADDED
Binary file (7.01 kB).
src/utils/__pycache__/seq2seq.cpython-310.pyc ADDED
Binary file (4.15 kB).
src/utils/__pycache__/template.cpython-310.pyc ADDED
Binary file (3.12 kB).
src/utils/common.py ADDED
@@ -0,0 +1,561 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import hashlib
5
+ from itertools import chain
6
+ from typing import List, Literal, Optional, Tuple
7
+
8
+ import transformers
9
+ from transformers import (
10
+ AutoConfig,
11
+ AutoModelForCausalLM,
12
+ AutoTokenizer,
13
+ HfArgumentParser,
14
+ Seq2SeqTrainingArguments,
15
+ BitsAndBytesConfig
16
+ )
17
+ from transformers.utils import check_min_version
18
+ from transformers.utils.versions import require_version
19
+ from transformers.modeling_utils import PreTrainedModel
20
+ from transformers.tokenization_utils import PreTrainedTokenizer
21
+
22
+ import datasets
23
+ from datasets import Dataset, concatenate_datasets, load_dataset
24
+
25
+ from peft import (
26
+ PeftModel,
27
+ TaskType,
28
+ LoraConfig,
29
+ get_peft_model
30
+ )
31
+
32
+ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
33
+
34
+ from trl import AutoModelForCausalLMWithValueHead
35
+
36
+ from .config import (
37
+ ModelArguments,
38
+ DataTrainingArguments,
39
+ FinetuningArguments,
40
+ GeneratingArguments
41
+ )
42
+
43
+ from .template import Template
44
+
45
+ from .other import (
46
+ get_logger,
47
+ load_trainable_params,
48
+ load_valuehead_params,
49
+ print_trainable_params,
50
+ prepare_model_for_training,
51
+ IGNORE_INDEX
52
+ )
53
+
54
+ check_min_version("4.29.1")
55
+ require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
56
+ require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
57
+ require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
58
+ require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
59
+
60
+
61
+ logger = get_logger(__name__)
62
+
63
+
64
+ def _init_adapter(
65
+ model: PreTrainedModel,
66
+ model_args: ModelArguments,
67
+ finetuning_args: FinetuningArguments,
68
+ is_trainable: bool,
69
+ is_mergeable: bool
70
+ ) -> PreTrainedModel:
71
+ r"""
72
+ Initializes the adapters.
73
+
74
+ Support full-parameter, freeze and LoRA training.
75
+
76
+ Note that the trainable parameters must be cast to float32.
77
+ """
78
+
79
+ if finetuning_args.finetuning_type == "none" and is_trainable:
80
+ raise ValueError("You cannot use finetuning_type=none while training.")
81
+
82
+ if finetuning_args.finetuning_type == "full":
83
+ logger.info("Fine-tuning method: Full")
84
+ model = model.float()
85
+
86
+ if finetuning_args.finetuning_type == "freeze":
87
+ logger.info("Fine-tuning method: Freeze")
88
+ for name, param in model.named_parameters():
89
+ if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
90
+ param.requires_grad_(False)
91
+ else:
92
+ param.data = param.data.to(torch.float32)
93
+
94
+ if model_args.checkpoint_dir is not None:
95
+ if finetuning_args.finetuning_type != "lora":
96
+ assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
97
+ assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
98
+ else:
99
+ assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
100
+
101
+ if finetuning_args.finetuning_type == "lora":
102
+ logger.info("Fine-tuning method: LoRA")
103
+ lastest_checkpoint = None
104
+
105
+ if model_args.checkpoint_dir is not None:
106
+ if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
107
+ not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
108
+ raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
109
+ please specify `--finetuning_type full/freeze` instead.")
110
+
111
+ if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
112
+ checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
113
+ else:
114
+ checkpoints_to_merge = model_args.checkpoint_dir
115
+
116
+ for checkpoint in checkpoints_to_merge:
117
+ model = PeftModel.from_pretrained(model, checkpoint)
118
+ model = model.merge_and_unload()
119
+
120
+ if len(checkpoints_to_merge) > 0:
121
+ logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
122
+
123
+ if lastest_checkpoint is not None: # resume lora training or quantized inference
124
+ model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable)
125
+
126
+ if is_trainable and lastest_checkpoint is None: # create new lora weights while training
127
+ lora_config = LoraConfig(
128
+ task_type=TaskType.CAUSAL_LM,
129
+ inference_mode=False,
130
+ r=finetuning_args.lora_rank,
131
+ lora_alpha=finetuning_args.lora_alpha,
132
+ lora_dropout=finetuning_args.lora_dropout,
133
+ target_modules=finetuning_args.lora_target
134
+ )
135
+ model = get_peft_model(model, lora_config)
136
+
137
+ if model_args.checkpoint_dir is not None:
138
+ logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
139
+
140
+ return model
141
+
142
+
143
+ def load_pretrained(
144
+ model_args: ModelArguments,
145
+ finetuning_args: FinetuningArguments,
146
+ is_trainable: Optional[bool] = False,
147
+ stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
148
+ ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
149
+ r"""
150
+ Loads pretrained model and tokenizer.
151
+
152
+ Support both training and inference.
153
+ """
154
+ if (not is_trainable) and model_args.checkpoint_dir is None:
155
+ logger.warning("Checkpoint is not found at evaluation, load the original model.")
156
+ finetuning_args = FinetuningArguments(finetuning_type="none")
157
+
158
+ assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
159
+ "RM and PPO training can only be performed with the LoRA method."
160
+
161
+ config_kwargs = {
162
+ "trust_remote_code": True,
163
+ "cache_dir": model_args.cache_dir,
164
+ "revision": model_args.model_revision,
165
+ "use_auth_token": True if model_args.use_auth_token else None,
166
+ }
167
+
168
+ tokenizer = AutoTokenizer.from_pretrained(
169
+ model_args.model_name_or_path,
170
+ use_fast=model_args.use_fast_tokenizer,
171
+ padding_side="left",
172
+ **config_kwargs
173
+ )
174
+ tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
175
+ tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)
176
+
177
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
178
+ is_mergeable = True
179
+
180
+ # Quantization configurations (using bitsandbytes library).
181
+ if model_args.quantization_bit is not None:
182
+ if model_args.quantization_bit == 8:
183
+ require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
184
+ config_kwargs["load_in_8bit"] = True
185
+ config_kwargs["quantization_config"] = BitsAndBytesConfig(
186
+ load_in_8bit=True,
187
+ llm_int8_threshold=6.0
188
+ )
189
+ elif model_args.quantization_bit == 4:
190
+ require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
191
+ require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
192
+ require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
193
+ require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
194
+ config_kwargs["load_in_4bit"] = True
195
+ config_kwargs["quantization_config"] = BitsAndBytesConfig(
196
+ load_in_4bit=True,
197
+ bnb_4bit_compute_dtype=model_args.compute_dtype,
198
+ bnb_4bit_use_double_quant=model_args.double_quantization,
199
+ bnb_4bit_quant_type=model_args.quantization_type
200
+ )
201
+ is_mergeable = False
202
+ config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
203
+ logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
204
+
205
+ if not is_trainable: # `device_map=auto` should be used for inference only
206
+ config_kwargs["device_map"] = "auto"
207
+
208
+ # Load and prepare pretrained models (without valuehead).
209
+ model = AutoModelForCausalLM.from_pretrained(
210
+ model_args.model_name_or_path,
211
+ config=config,
212
+ torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
213
+ low_cpu_mem_usage=True,
214
+ **config_kwargs
215
+ )
216
+ model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
217
+ model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
218
+
219
+ if stage == "rm" or stage == "ppo": # add value head
220
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
221
+
222
+ if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
223
+ logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
224
+ if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
225
+ model.v_head.load_state_dict({
226
+ "summary.weight": getattr(model, "reward_head_weight"),
227
+ "summary.bias": getattr(model, "reward_head_bias")
228
+ })
229
+
230
+ if stage == "ppo": # load reward model
231
+ assert is_trainable, "PPO stage cannot be performed at evaluation."
232
+ assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
233
+ logger.info("Load reward model from {}".format(model_args.reward_model))
234
+ model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
235
+ assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
236
+
237
+ if not is_trainable:
238
+ model.requires_grad_(False) # fix all model params
239
+ model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
240
+
241
+ print_trainable_params(model)
242
+
243
+ return model, tokenizer
244
+
245
+
246
+ def prepare_args(
247
+ stage: Literal["pt", "sft", "rm", "ppo"]
248
+ ) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
249
+
250
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
251
+
252
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
253
+ model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
254
+ else:
255
+ model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
256
+
257
+ # Setup logging
258
+ if training_args.should_log:
259
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
260
+ transformers.utils.logging.set_verbosity_info()
261
+
262
+ log_level = training_args.get_process_log_level()
263
+ datasets.utils.logging.set_verbosity(log_level)
264
+ transformers.utils.logging.set_verbosity(log_level)
265
+ transformers.utils.logging.enable_default_handler()
266
+ transformers.utils.logging.enable_explicit_format()
267
+
268
+ # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
269
+ if stage != "sft" and training_args.predict_with_generate:
270
+ raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
271
+
272
+ if training_args.do_train and training_args.predict_with_generate:
273
+ raise ValueError("`predict_with_generate` cannot be set as True while training.")
274
+
275
+ if training_args.do_predict and (not training_args.predict_with_generate):
276
+ raise ValueError("Please enable `predict_with_generate` to save model predictions.")
277
+
278
+ if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
279
+ raise ValueError("Quantization is only compatible with the LoRA method.")
280
+
281
+ if model_args.quantization_bit is not None and (not training_args.do_train):
282
+ logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
283
+
284
+ if training_args.do_train and (not training_args.fp16):
285
+ logger.warning("We recommend enable fp16 mixed precision training.")
286
+
287
+ if data_args.prompt_template == "alpaca":
288
+ logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
289
+
290
+ if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
291
+ logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
292
+ training_args.ddp_find_unused_parameters = False
293
+
294
+ training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
295
+
296
+ if model_args.quantization_bit is not None:
297
+ if training_args.fp16:
298
+ model_args.compute_dtype = torch.float16
299
+ elif training_args.bf16:
300
+ model_args.compute_dtype = torch.bfloat16
301
+ else:
302
+ model_args.compute_dtype = torch.float32
303
+
304
+ # Log on each process the small summary:
305
+ logger.info(
306
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
307
+ + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
308
+ )
309
+ logger.info(f"Training/evaluation parameters {training_args}")
310
+
311
+ # Set seed before initializing model.
312
+ transformers.set_seed(training_args.seed)
313
+
314
+ return model_args, data_args, training_args, finetuning_args
315
+
316
+
317
+ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:
318
+
319
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))
320
+
321
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
322
+ model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
323
+ else:
324
+ model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
325
+
326
+ if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
327
+ raise ValueError("Quantization is only compatible with the LoRA method.")
328
+
329
+ if data_args.prompt_template == "alpaca":
330
+ logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
331
+
332
+ return model_args, data_args, finetuning_args, generating_args
333
+
334
+
335
+ def prepare_data(
336
+ model_args: ModelArguments,
337
+ data_args: DataTrainingArguments
338
+ ) -> Dataset:
339
+
340
+ def checksum(file_path, hash):
341
+ with open(file_path, "rb") as datafile:
342
+ binary_data = datafile.read()
343
+ sha1 = hashlib.sha1(binary_data).hexdigest()
344
+ if sha1 != hash:
345
+ logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
346
+
347
+ max_samples = data_args.max_samples
348
+ all_datasets: List[Dataset] = [] # support multiple datasets
349
+
350
+ for dataset_attr in data_args.dataset_list:
351
+
352
+ logger.info("Loading dataset {}...".format(dataset_attr))
353
+
354
+ if dataset_attr.load_from == "hf_hub":
355
+ raw_datasets = load_dataset(dataset_attr.dataset_name, cache_dir=model_args.cache_dir)
356
+ elif dataset_attr.load_from == "script":
357
+ raw_datasets = load_dataset(
358
+ os.path.join(data_args.dataset_dir, dataset_attr.dataset_name),
359
+ cache_dir=model_args.cache_dir
360
+ )
361
+ elif dataset_attr.load_from == "file":
362
+ data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
363
+
364
+ extension = dataset_attr.file_name.split(".")[-1]
365
+ if extension == "csv":
366
+ file_type = "csv"
367
+ elif extension == "json" or extension == "jsonl":
368
+ file_type = "json"
369
+ else:
370
+ file_type = "text"
371
+
372
+ if dataset_attr.file_sha1 is not None:
373
+ checksum(data_file, dataset_attr.file_sha1)
374
+ else:
375
+ logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
376
+
377
+ raw_datasets = load_dataset(
378
+ file_type,
379
+ data_files=data_file,
380
+ cache_dir=model_args.cache_dir,
381
+ use_auth_token=True if model_args.use_auth_token else None
382
+ )
383
+ else:
384
+ raise NotImplementedError
385
+
386
+ dataset = raw_datasets[data_args.split]
387
+
388
+ if max_samples is not None:
389
+ max_samples_temp = min(len(dataset), max_samples)
390
+ dataset = dataset.select(range(max_samples_temp))
391
+
392
+ dummy_data = [None] * len(dataset)
393
+ for column_name, target_name in [
394
+ ("prompt_column", "prompt"),
395
+ ("query_column", "query"),
396
+ ("response_column", "response"),
397
+ ("history_column", "history")
398
+ ]: # every dataset will have 4 columns same as each other
399
+ if getattr(dataset_attr, column_name) != target_name:
400
+ if getattr(dataset_attr, column_name):
401
+ dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
402
+ else: # None or empty string
403
+ dataset = dataset.add_column(target_name, dummy_data)
404
+ all_datasets.append(dataset)
405
+
406
+ if len(data_args.dataset_list) == 1:
407
+ all_datasets = all_datasets[0]
408
+ else:
409
+ all_datasets = concatenate_datasets(all_datasets)
410
+
411
+ return all_datasets
412
+
413
+
414
+ def preprocess_data(
415
+ dataset: Dataset,
416
+ tokenizer: PreTrainedTokenizer,
417
+ data_args: DataTrainingArguments,
418
+ training_args: Seq2SeqTrainingArguments,
419
+ stage: Literal["pt", "sft", "rm", "ppo"]
420
+ ) -> Dataset:
421
+
422
+ column_names = list(dataset.column_names)
423
+ prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
424
+ prompt_template = Template(data_args.prompt_template)
425
+
426
+ # support question with a single answer or multiple answers
427
+ def get_dialog(examples):
428
+ for i in range(len(examples["prompt"])):
429
+ if examples["prompt"][i] and examples["response"][i]:
430
+ query, answer = examples["prompt"][i], examples["response"][i]
431
+ query = query + "\n" + examples["query"][i] if examples["query"][i] else query
432
+ dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
433
+ yield dialog
434
+
435
+ def preprocess_pretrain_dataset(examples):
436
+ # build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
437
+ text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
438
+ concatenated_ids = list(chain(*text_ids))
439
+ total_length = len(concatenated_ids)
440
+ block_size = data_args.max_source_length - 1
441
+ # we drop the small remainder, and if the total_length < block_size, we exclude this batch
442
+ total_length = (total_length // block_size) * block_size
443
+ # split by chunks of max_source_length
444
+ result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
445
+ for i in range(0, total_length, block_size)]
446
+ return {
447
+ "input_ids": result,
448
+ "labels": result.copy()
449
+ }
450
+
451
+ def preprocess_supervised_dataset(examples):
452
+ # build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
453
+ # for input with history, we build multiple input-label pairs just like:
454
+ # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
455
+ model_inputs = {"input_ids": [], "labels": []}
456
+ for dialog in get_dialog(examples):
457
+ input_ids, labels = [], []
458
+
459
+ for i in range(len(dialog) // 2):
460
+ source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
461
+ target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
462
+ input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
463
+ labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
464
+
465
+ model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
466
+ model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
467
+ return model_inputs
468
+
469
+ def preprocess_unsupervised_dataset(examples):
470
+ # build inputs with format `X [BOS]` and labels with format `Y [BOS]`
471
+ model_inputs = {"input_ids": [], "labels": []}
472
+ for dialog in get_dialog(examples):
473
+ prompt, answer = "".join(dialog[:-1]), dialog[-1]
474
+
475
+ source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
476
+ target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
477
+
478
+ if len(source_ids) > data_args.max_source_length - 1: # bos token
479
+ source_ids = source_ids[:data_args.max_source_length - 1]
480
+ if len(target_ids) > data_args.max_target_length - 1: # bos token
481
+ target_ids = target_ids[:data_args.max_target_length - 1]
482
+
483
+ input_ids = source_ids + [tokenizer.bos_token_id]
484
+ labels = target_ids + [tokenizer.bos_token_id]
485
+
486
+ model_inputs["input_ids"].append(input_ids)
487
+ model_inputs["labels"].append(labels)
488
+ return model_inputs
489
+
490
+ def preprocess_pairwise_dataset(examples):
491
+ # build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
492
+ model_inputs = {"accept_ids": [], "reject_ids": []}
493
+ for dialog in get_dialog(examples):
494
+ prompt, answer = "".join(dialog[:-1]), dialog[-1]
495
+
496
+ source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
497
+ accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
498
+ reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
499
+
500
+ if len(source_ids) > data_args.max_source_length - 1: # bos token
501
+ source_ids = source_ids[:data_args.max_source_length - 1]
502
+ if len(accept_ids) > data_args.max_target_length - 1: # eos token
503
+ accept_ids = accept_ids[:data_args.max_target_length - 1]
504
+ if len(reject_ids) > data_args.max_target_length - 1: # eos token
505
+ reject_ids = reject_ids[:data_args.max_target_length - 1]
506
+
507
+ accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
508
+ reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
509
+
510
+ model_inputs["accept_ids"].append(accept_ids)
511
+ model_inputs["reject_ids"].append(reject_ids)
512
+ return model_inputs
513
+
514
+ def print_supervised_dataset_example(example):
515
+ print("input_ids:\n{}".format(example["input_ids"]))
516
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
517
+ print("label_ids:\n{}".format(example["labels"]))
518
+ print("labels:\n{}".format(
519
+ tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]]))
520
+ )
521
+
522
+ def print_pairwise_dataset_example(example):
523
+ print("accept_ids:\n{}".format(example["accept_ids"]))
524
+ print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
525
+ print("reject_ids:\n{}".format(example["reject_ids"]))
526
+ print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))
527
+
528
+ def print_unsupervised_dataset_example(example):
529
+ print("input_ids:\n{}".format(example["input_ids"]))
530
+ print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
531
+
532
+ if stage == "pt":
533
+ preprocess_function = preprocess_pretrain_dataset
534
+ elif stage == "sft":
535
+ preprocess_function = preprocess_unsupervised_dataset \
536
+ if training_args.predict_with_generate else preprocess_supervised_dataset
537
+ elif stage == "rm":
538
+ preprocess_function = preprocess_pairwise_dataset
539
+ elif stage == "ppo":
540
+ preprocess_function = preprocess_unsupervised_dataset
541
+
542
+ with training_args.main_process_first(desc="dataset map pre-processing"):
543
+ dataset = dataset.map(
544
+ preprocess_function,
545
+ batched=True,
546
+ num_proc=data_args.preprocessing_num_workers,
547
+ remove_columns=column_names,
548
+ load_from_cache_file=not data_args.overwrite_cache,
549
+ desc="Running tokenizer on dataset"
550
+ )
551
+
552
+ if stage == "pt":
553
+ print_unsupervised_dataset_example(dataset[0])
554
+ elif stage == "sft":
555
+ print_supervised_dataset_example(dataset[0])
556
+ elif stage == "rm":
557
+ print_pairwise_dataset_example(dataset[0])
558
+ elif stage == "ppo":
559
+ print_unsupervised_dataset_example(dataset[0])
560
+
561
+ return dataset
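To make the label masking in `preprocess_supervised_dataset` above concrete: each turn is encoded as `X [BOS] Y [EOS]`, and every position belonging to the prompt (plus the BOS token) is replaced by `IGNORE_INDEX` in the labels, so the loss is computed only on the response. A schematic sketch with made-up token ids follows; the value of `IGNORE_INDEX` is an assumption, since the real constant lives in utils/other.py.

# Schematic illustration of the supervised preprocessing (token ids are made up).
IGNORE_INDEX = -100  # assumed value; the real constant is imported from utils/other.py

source_ids = [5, 6, 7]   # tokenized prompt "X"
target_ids = [8, 9]      # tokenized response "Y"
bos_id, eos_id = 1, 2

input_ids = source_ids + [bos_id] + target_ids + [eos_id]                 # [5, 6, 7, 1, 8, 9, 2]
labels = [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [eos_id]   # [-100, -100, -100, -100, 8, 9, 2]
# Loss is therefore computed only on "Y [EOS]".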
src/utils/config.py ADDED
@@ -0,0 +1,283 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ from typing import Any, Dict, List, Literal, Optional
5
+ from dataclasses import asdict, dataclass, field
6
+
7
+
8
+ @dataclass
9
+ class DatasetAttr:
10
+
11
+ load_from: str
12
+ dataset_name: Optional[str] = None
13
+ file_name: Optional[str] = None
14
+ file_sha1: Optional[str] = None
15
+
16
+ def __repr__(self) -> str:
17
+ if self.dataset_name is not None:
18
+ return self.dataset_name
19
+ else:
20
+ return self.file_name
21
+
22
+ def __post_init__(self):
23
+ self.prompt_column = "instruction"
24
+ self.query_column = "input"
25
+ self.response_column = "output"
26
+ self.history_column = None
27
+
28
+
29
+ @dataclass
30
+ class ModelArguments:
31
+ """
32
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
33
+ """
34
+ model_name_or_path: str = field(
35
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
36
+ )
37
+ cache_dir: Optional[str] = field(
38
+ default=None,
39
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
40
+ )
41
+ use_fast_tokenizer: Optional[bool] = field(
42
+ default=False,
43
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
44
+ )
45
+ use_auth_token: Optional[bool] = field(
46
+ default=False,
47
+ metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
48
+ )
49
+ model_revision: Optional[str] = field(
50
+ default="main",
51
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
52
+ )
53
+ quantization_bit: Optional[int] = field(
54
+ default=None,
55
+ metadata={"help": "The number of bits to quantize the model."}
56
+ )
57
+ quantization_type: Optional[Literal["fp4", "nf4"]] = field(
58
+ default="nf4",
59
+ metadata={"help": "Quantization data type to use in int4 training."}
60
+ )
61
+ double_quantization: Optional[bool] = field(
62
+ default=True,
63
+ metadata={"help": "Whether to use double quantization in int4 training or not."}
64
+ )
65
+ compute_dtype: Optional[torch.dtype] = field(
66
+ default=None,
67
+ metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
68
+ )
69
+ checkpoint_dir: Optional[str] = field(
70
+ default=None,
71
+ metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
72
+ )
73
+ reward_model: Optional[str] = field(
74
+ default=None,
75
+ metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
76
+ )
77
+ resume_lora_training: Optional[bool] = field(
78
+ default=True,
79
+ metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
80
+ )
81
+ plot_loss: Optional[bool] = field(
82
+ default=False,
83
+ metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
84
+ )
85
+
86
+ def __post_init__(self):
87
+ if self.checkpoint_dir is not None: # support merging multiple lora weights
88
+ self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
89
+
90
+ if self.quantization_bit is not None:
91
+ assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
92
+
93
+ @dataclass
94
+ class DataTrainingArguments:
95
+ """
96
+ Arguments pertaining to what data we are going to input our model for training and evaluation.
97
+ """
98
+ dataset: Optional[str] = field(
99
+ default="alpaca_zh",
100
+ metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
101
+ )
102
+ dataset_dir: Optional[str] = field(
103
+ default="data",
104
+ metadata={"help": "The name of the folder containing datasets."}
105
+ )
106
+ split: Optional[str] = field(
107
+ default="train",
108
+ metadata={"help": "Which dataset split to use for training and evaluation."}
109
+ )
110
+ overwrite_cache: Optional[bool] = field(
111
+ default=False,
112
+ metadata={"help": "Overwrite the cached training and evaluation sets."}
113
+ )
114
+ preprocessing_num_workers: Optional[int] = field(
115
+ default=None,
116
+ metadata={"help": "The number of processes to use for the preprocessing."}
117
+ )
118
+ max_source_length: Optional[int] = field(
119
+ default=512,
120
+ metadata={"help": "The maximum total input sequence length after tokenization."}
121
+ )
122
+ max_target_length: Optional[int] = field(
123
+ default=512,
124
+ metadata={"help": "The maximum total output sequence length after tokenization."}
125
+ )
126
+ max_samples: Optional[int] = field(
127
+ default=None,
128
+ metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
129
+ )
130
+ eval_num_beams: Optional[int] = field(
131
+ default=None,
132
+ metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
133
+ )
134
+ ignore_pad_token_for_loss: Optional[bool] = field(
135
+ default=True,
136
+ metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
137
+ )
138
+ source_prefix: Optional[str] = field(
139
+ default=None,
140
+ metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
141
+ )
142
+ dev_ratio: Optional[float] = field(
143
+ default=0,
144
+ metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
145
+ )
146
+ prompt_template: Optional[str] = field(
147
+ default="alpaca",
148
+ metadata={"help": "Which template to use for constructing prompts in training and inference."}
149
+ )
150
+
151
+ def __post_init__(self): # support mixing multiple datasets
152
+ dataset_names = [ds.strip() for ds in self.dataset.split(",")]
153
+ with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
154
+ dataset_info = json.load(f)
155
+
156
+ self.dataset_list: List[DatasetAttr] = []
157
+ for name in dataset_names:
158
+ if name not in dataset_info:
159
+ raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
160
+
161
+ if "hf_hub_url" in dataset_info[name]:
162
+ dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
163
+ elif "script_url" in dataset_info[name]:
164
+ dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
165
+ else:
166
+ dataset_attr = DatasetAttr(
167
+ "file",
168
+ file_name=dataset_info[name]["file_name"],
169
+ file_sha1=dataset_info[name].get("file_sha1", None)
170
+ )
171
+
172
+ if "columns" in dataset_info[name]:
173
+ dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
174
+ dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
175
+ dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
176
+ dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
177
+
178
+ self.dataset_list.append(dataset_attr)
179
+
180
+
181
+ @dataclass
182
+ class FinetuningArguments:
183
+ """
184
+ Arguments pertaining to which techniques we are going to fine-tune with.
185
+ """
186
+ finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
187
+ default="lora",
188
+ metadata={"help": "Which fine-tuning method to use."}
189
+ )
190
+ num_layer_trainable: Optional[int] = field(
191
+ default=3,
192
+ metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
193
+ )
194
+ name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
195
+ default="mlp",
196
+ metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
197
+ LLaMA choices: [\"mlp\", \"self_attn\"], \
198
+ BLOOM choices: [\"mlp\", \"self_attention\"], \
199
+ Baichuan choices: [\"mlp\", \"self_attn\"]"}
200
+ )
201
+ lora_rank: Optional[int] = field(
202
+ default=8,
203
+ metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
204
+ )
205
+ lora_alpha: Optional[float] = field(
206
+ default=32.0,
207
+ metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
208
+ )
209
+ lora_dropout: Optional[float] = field(
210
+ default=0.1,
211
+ metadata={"help": "Dropout rate for the LoRA fine-tuning."}
212
+ )
213
+ lora_target: Optional[str] = field(
214
+ default="q_proj,v_proj",
215
+ metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
216
+ LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
217
+ BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
218
+ Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
219
+ )
220
+
221
+ def __post_init__(self):
222
+ if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
223
+ self.lora_target = [target.strip() for target in self.lora_target.split(",")]
224
+
225
+ if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
226
+ trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)]
227
+ else: # fine-tuning the first n layers if num_layer_trainable < 0
228
+ trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
229
+
230
+ self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
231
+
232
+ assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
233
+
234
+ def save_to_json(self, json_path: str):
235
+ """Saves the content of this instance in JSON format inside `json_path`."""
236
+ json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
237
+ with open(json_path, "w", encoding="utf-8") as f:
238
+ f.write(json_string)
239
+
240
+ @classmethod
241
+ def load_from_json(cls, json_path: str):
242
+ """Creates an instance from the content of `json_path`."""
243
+ with open(json_path, "r", encoding="utf-8") as f:
244
+ text = f.read()
245
+ return cls(**json.loads(text))
246
+
247
+
248
+ @dataclass
249
+ class GeneratingArguments:
250
+ """
251
+ Arguments pertaining to the decoding parameters.
252
+ """
253
+ do_sample: Optional[bool] = field(
254
+ default=True,
255
+ metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
256
+ )
257
+ temperature: Optional[float] = field(
258
+ default=0.95,
259
+ metadata={"help": "The value used to modulate the next token probabilities."}
260
+ )
261
+ top_p: Optional[float] = field(
262
+ default=0.7,
263
+ metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
264
+ )
265
+ top_k: Optional[int] = field(
266
+ default=50,
267
+ metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
268
+ )
269
+ num_beams: Optional[int] = field(
270
+ default=1,
271
+ metadata={"help": "Number of beams for beam search. 1 means no beam search."}
272
+ )
273
+ max_new_tokens: Optional[int] = field(
274
+ default=512,
275
+ metadata={"help": "The maximum number of tokens to generate, ignoring the number of tokens in the prompt."}
276
+ )
277
+ repetition_penalty: Optional[float] = field(
278
+ default=1.0,
279
+ metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
280
+ )
281
+
282
+ def to_dict(self) -> Dict[str, Any]:
283
+ return asdict(self)
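For illustration, a minimal standalone sketch (not part of the commit) of how FinetuningArguments.__post_init__ above resolves the Freeze fine-tuning layers and parses lora_target; the 28-layer assumption (indices 0-27) mirrors the hard-coded 27 - k in the code.

# Sketch of the layer-selection and lora_target parsing in FinetuningArguments.__post_init__.
def resolve_freeze_targets(num_layer_trainable: int, name_module_trainable: str = "mlp"):
    if num_layer_trainable > 0:  # last n layers of a 28-layer backbone (ids 0-27)
        trainable_layer_ids = [27 - k for k in range(num_layer_trainable)]
    else:  # first |n| layers
        trainable_layer_ids = [k for k in range(-num_layer_trainable)]
    return ["layers.{:d}.{}".format(idx, name_module_trainable) for idx in trainable_layer_ids]

print(resolve_freeze_targets(3))   # ['layers.27.mlp', 'layers.26.mlp', 'layers.25.mlp']
print(resolve_freeze_targets(-2))  # ['layers.0.mlp', 'layers.1.mlp']
print([t.strip() for t in "q_proj, v_proj".split(",")])  # lora_target parsing -> ['q_proj', 'v_proj']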
src/utils/data_collator.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+
3
+ from typing import Dict, Optional, Sequence, Union
4
+
5
+ from transformers import DataCollatorWithPadding, BatchEncoding
6
+ from transformers.tokenization_utils import PreTrainedTokenizer
7
+
8
+ from .other import IGNORE_INDEX
9
+
10
+
11
+ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
12
+ r"""
13
+ Inherits DataCollatorWithPadding. It is capable of dynamically padding batched data.
14
+ """
15
+ def __init__(
16
+ self,
17
+ tokenizer: PreTrainedTokenizer,
18
+ ignore_pad_token_for_loss: Optional[bool] = False
19
+ ):
20
+ super().__init__(tokenizer, padding=True)
21
+ self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id
22
+
23
+ def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
24
+ r"""
25
+ Generates attention masks for left-padded sequences.
26
+ """
27
+ batch_size, seq_length = input_ids.size()
28
+ attention_mask = torch.ones((batch_size, seq_length), device=device)
29
+ for i, seq in enumerate(input_ids):
30
+ attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding
31
+ attention_mask = attention_mask.bool()
32
+ return attention_mask
33
+
34
+ def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
35
+ r"""
36
+ Pads batched data to the longest sequence in the batch.
37
+
38
+ We adopt left-padding in both training and evaluation.
39
+ """
40
+ if isinstance(features[0]["input_ids"], torch.Tensor):
41
+ input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features]
42
+ else:
43
+ input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features]
44
+
45
+ if "labels" in features[0]:
46
+ if isinstance(features[0]["labels"], torch.Tensor):
47
+ labels = [feature["labels"].clone().detach().flip(0) for feature in features]
48
+ else:
49
+ labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
50
+ input_ids = input_ids + labels # pad them to the same length
51
+
52
+ input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id).flip(-1)
53
+
54
+ batch = {}
55
+
56
+ if "labels" in features[0]:
57
+ input_ids, labels = input_ids.split(len(features), dim=0)
58
+ labels = torch.where(labels != self.tokenizer.pad_token_id, labels, self.label_pad_token_id)
59
+ batch["labels"] = labels
60
+
61
+ batch["input_ids"] = input_ids
62
+ batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
63
+
64
+ return BatchEncoding(batch)
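To make the left-padding trick above concrete, here is a small standalone sketch (not part of the commit): each sequence is reversed, padded on the right with torch's pad_sequence, and flipped back, which yields left padding; the attention mask then simply marks the non-pad positions.

import torch

# Sketch of the flip-pad-flip approach used in DynamicDataCollatorWithPadding.__call__.
pad_id = 0
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
flipped = [s.flip(0) for s in seqs]  # reverse each sequence
padded = torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=pad_id)
padded = padded.flip(-1)             # reverse back -> left padding
print(padded)                        # tensor([[5, 6, 7], [0, 8, 9]])
print(padded != pad_id)              # attention mask: pad positions are False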
src/utils/other.py ADDED
@@ -0,0 +1,196 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import torch
5
+ import logging
6
+ from typing import Dict, List, Optional
7
+
8
+ from transformers.trainer import TRAINER_STATE_NAME
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ from transformers.generation.utils import LogitsProcessorList
11
+ from transformers.generation.logits_process import LogitsProcessor
12
+
13
+ from peft.utils import WEIGHTS_NAME
14
+
15
+
16
+ IGNORE_INDEX = -100
17
+ VALUE_HEAD_FILE_NAME = "value_head.bin"
18
+ FINETUNING_ARGS_NAME = "finetuning_args.json"
19
+
20
+
21
+ def get_logger(name: str) -> logging.Logger:
22
+ return logging.getLogger(name)
23
+
24
+
25
+ logging.basicConfig(
26
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
27
+ datefmt="%m/%d/%Y %H:%M:%S",
28
+ level=logging.INFO,
29
+ handlers=[logging.StreamHandler(sys.stdout)]
30
+ )
31
+
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
+ class AverageMeter:
37
+ r"""
38
+ Computes and stores the average and current value.
39
+ """
40
+ def __init__(self):
41
+ self.reset()
42
+
43
+ def reset(self):
44
+ self.val = 0
45
+ self.avg = 0
46
+ self.sum = 0
47
+ self.count = 0
48
+
49
+ def update(self, val, n=1):
50
+ self.val = val
51
+ self.sum += val * n
52
+ self.count += n
53
+ self.avg = self.sum / self.count
54
+
55
+
56
+ # Avoid runtime error in model.generate(do_sample=True).
57
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
58
+
59
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
60
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
61
+ scores.zero_()
62
+ scores[..., 0] = 1.0
63
+ return scores
64
+
65
+
66
+ def get_logits_processor() -> LogitsProcessorList:
67
+ logits_processor = LogitsProcessorList()
68
+ logits_processor.append(InvalidScoreLogitsProcessor())
69
+ return logits_processor
70
+
71
+
72
+ # Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
73
+ # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
74
+ def prepare_model_for_training(
75
+ model: PreTrainedModel,
76
+ finetuning_type: str,
77
+ output_embedding_layer_name: Optional[str] = "lm_head",
78
+ use_gradient_checkpointing: Optional[bool] = True,
79
+ layer_norm_names: Optional[List[str]] = ["norm", "ln_f"] # for LLaMA and BLOOM setting
80
+ ) -> PreTrainedModel:
81
+
82
+ for name, param in model.named_parameters():
83
+ if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
84
+ param.data = param.data.to(torch.float32)
85
+
86
+ if use_gradient_checkpointing:
87
+ if hasattr(model, "enable_input_require_grads"):
88
+ model.enable_input_require_grads()
89
+ else:
90
+ def make_inputs_require_grad(module, input, output):
91
+ output.requires_grad_(True)
92
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
93
+
94
+ model.gradient_checkpointing_enable()
95
+ model.config.use_cache = False # turn off when gradient checkpointing is enabled
96
+
97
+ if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
98
+ output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
99
+ input_dtype = output_embedding_layer.weight.dtype
100
+
101
+ class CastOutputToFloat(torch.nn.Sequential):
102
+
103
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
104
+ return super().forward(x.to(input_dtype)).to(torch.float32)
105
+
106
+ setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
107
+
108
+ return model
109
+
110
+
111
+ def print_trainable_params(model: torch.nn.Module) -> None:
112
+ trainable_params, all_param = 0, 0
113
+ for param in model.parameters():
114
+ num_params = param.numel()
115
+ # if using DS Zero 3 and the weights are initialized empty
116
+ if num_params == 0 and hasattr(param, "ds_numel"):
117
+ num_params = param.ds_numel
118
+ all_param += num_params
119
+ if param.requires_grad:
120
+ trainable_params += num_params
121
+ print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
122
+ trainable_params, all_param, 100 * trainable_params / all_param))
123
+
124
+
125
+ def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
126
+ state_dict = model.state_dict()
127
+ filtered_state_dict = {}
128
+
129
+ for k, v in model.named_parameters():
130
+ if v.requires_grad:
131
+ filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
132
+
133
+ return filtered_state_dict
134
+
135
+
136
+ def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
137
+ weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
138
+ if not os.path.exists(weights_file):
139
+ logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
140
+ return False
141
+ model_state_dict = torch.load(weights_file, map_location="cpu")
142
+ model.load_state_dict(model_state_dict, strict=False) # skip missing keys
143
+ return True
144
+
145
+
146
+ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
147
+ valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
148
+ if not os.path.exists(valuehead_file):
149
+ logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
150
+ return False
151
+ valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
152
+ model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
153
+ model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
154
+ model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
155
+ model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
156
+ return True
157
+
158
+
159
+ def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
160
+ r"""
161
+ EMA implementation according to TensorBoard.
162
+ """
163
+ last = scalars[0]
164
+ smoothed = list()
165
+ for next_val in scalars:
166
+ smoothed_val = last * weight + (1 - weight) * next_val
167
+ smoothed.append(smoothed_val)
168
+ last = smoothed_val
169
+ return smoothed
170
+
171
+
172
+ def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
173
+ import matplotlib.pyplot as plt
174
+ with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
175
+ data = json.load(f)
176
+
177
+ for key in keys:
178
+ steps, metrics = [], []
179
+ for i in range(len(data["log_history"])):
180
+ if key in data["log_history"][i]:
181
+ steps.append(data["log_history"][i]["step"])
182
+ metrics.append(data["log_history"][i][key])
183
+
184
+ if len(metrics) == 0:
185
+ logger.warning(f"No metric {key} to plot.")
186
+ continue
187
+
188
+ plt.figure()
189
+ plt.plot(steps, metrics, alpha=0.4, label="original")
190
+ plt.plot(steps, smooth(metrics), label="smoothed")
191
+ plt.title("training {} of {}".format(key, save_dictionary))
192
+ plt.xlabel("step")
193
+ plt.ylabel(key)
194
+ plt.legend()
195
+ plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
196
+ print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
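As a quick numeric check of the EMA used by smooth() above (the same weighting TensorBoard applies to its scalar charts), here is a standalone sketch, not part of the commit:

# Sketch of the exponential moving average in smooth(): s_t = w * s_{t-1} + (1 - w) * x_t.
def ema(scalars, weight=0.9):
    last = scalars[0]
    smoothed = []
    for value in scalars:
        last = last * weight + (1 - weight) * value
        smoothed.append(last)
    return smoothed

print(ema([1.0, 0.0, 0.0]))  # [1.0, 0.9, 0.81] -- each step keeps 90% of the previous estimate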
src/utils/pairwise.py ADDED
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import numpy as np
3
+ from typing import Dict, Sequence, Tuple, Union
4
+
5
+ from .data_collator import DynamicDataCollatorWithPadding
6
+
7
+ from .peft_trainer import PeftTrainer
8
+
9
+ from .other import get_logger
10
+
11
+ logger = get_logger(__name__)
12
+
13
+
14
+ def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
15
+ preds, _ = eval_preds
16
+ preds = np.array(preds)
17
+ return {"accuracy": (preds[:, 0] > preds[:, 1]).sum() / len(preds)}
18
+
19
+
20
+ class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
21
+ r"""
22
+ Data collator for pairwise data.
23
+ """
24
+
25
+ def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
26
+ r"""
27
+ Pads batched data to the longest sequence in the batch.
28
+
29
+ We generate 2 * n examples where the first n examples represent chosen examples and
30
+ the last n examples represent rejected examples.
31
+ """
32
+ features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
33
+ return super().__call__(features)
34
+
35
+
36
+ class PairwisePeftTrainer(PeftTrainer):
37
+ r"""
38
+ Inherits PeftTrainer to compute pairwise loss.
39
+ """
40
+
41
+ def __init__(self, *args, **kwargs):
42
+ super().__init__(*args, **kwargs)
43
+ self.can_return_loss = True # override property to return eval_loss
44
+
45
+ def compute_loss(self, model, inputs, return_outputs=False):
46
+ r"""
47
+ Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
48
+
49
+ We use score on the EOS token to represent reward of the whole sentence.
50
+
51
+ Subclass and override to inject custom behavior. It should not be directly used by external scripts.
52
+ """
53
+ batch_size = inputs["input_ids"].size(0) // 2
54
+ _, _, values = model(**inputs)
55
+ r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
56
+ loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
57
+ return (loss, torch.stack((r_accept, r_reject), dim=-1)) if return_outputs else loss
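The pairwise objective above is the standard ranking loss on two reward scores, -log(sigmoid(r_accept - r_reject)); a standalone sketch (not part of the commit) of the loss and of the paired accuracy computed by compute_accuracy:

import torch
import numpy as np

# Toy batch of two chosen/rejected score pairs.
r_accept = torch.tensor([1.5, 0.2])
r_reject = torch.tensor([0.5, 0.8])

loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(round(loss.item(), 4))  # ~0.6754: the first pair is ranked correctly, the second is not

# compute_accuracy() counts how often the chosen score beats the rejected one.
preds = np.stack([r_accept.numpy(), r_reject.numpy()], axis=-1)
print((preds[:, 0] > preds[:, 1]).sum() / len(preds))  # 0.5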
src/utils/peft_trainer.py ADDED
@@ -0,0 +1,132 @@
1
+ import os
2
+ import json
3
+ import time
4
+ import torch
5
+ from typing import Dict, Optional
6
+ from datetime import timedelta
7
+
8
+ from transformers import (
9
+ Seq2SeqTrainer,
10
+ TrainerCallback,
11
+ TrainerControl,
12
+ TrainerState,
13
+ TrainingArguments
14
+ )
15
+
16
+ from transformers.trainer import TRAINING_ARGS_NAME
17
+ from transformers.modeling_utils import unwrap_model
18
+
19
+ from peft.utils.other import WEIGHTS_NAME
20
+
21
+ from .config import FinetuningArguments
22
+
23
+ from .other import (
24
+ get_logger,
25
+ get_state_dict,
26
+ load_trainable_params,
27
+ load_valuehead_params,
28
+ FINETUNING_ARGS_NAME,
29
+ VALUE_HEAD_FILE_NAME
30
+ )
31
+
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
+ class LogCallback(TrainerCallback):
37
+ r"""
38
+ A TrainerCallback that records the training state; for more details, refer to the TrainerCallback class.
39
+ The on_log function primarily collects process parameters during training, such as training loss, learning rate,
40
+ and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
41
+ time. Every time a log is triggered, a new record is appended to the file "trainer_log.jsonl" for dynamic visualization
42
+ purposes.
43
+ """
44
+
45
+ def __init__(self):
46
+ self.start_time = time.time()
47
+
48
+ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
49
+ r"""
50
+ Event called after logging the last logs.
51
+ """
52
+ if "loss" not in state.log_history[-1]:
53
+ return
54
+ cur_time = time.time()
55
+ cur_steps = state.log_history[-1].get("step")
56
+ elapsed_time = cur_time - self.start_time
57
+ avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
58
+ remaining_steps = state.max_steps - cur_steps
59
+ remaining_time = remaining_steps * avg_time_per_step
60
+ log_dict = {
61
+ "current_steps": cur_steps,
62
+ "total_steps": state.max_steps,
63
+ "loss": state.log_history[-1].get("loss", None),
64
+ "reward": state.log_history[-1].get("reward", None),
65
+ "learning_rate": state.log_history[-1].get("learning_rate", None),
66
+ "epoch": state.log_history[-1].get("epoch", None),
67
+ "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
68
+ "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
69
+ "remaining_time": str(timedelta(seconds=int(remaining_time)))
70
+ }
71
+ os.makedirs(args.output_dir, exist_ok=True)
72
+ with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
73
+ f.write(json.dumps(log_dict) + "\n")
74
+
75
+
76
+ class PeftTrainer(Seq2SeqTrainer):
77
+ r"""
78
+ Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
79
+ """
80
+
81
+ def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
82
+ super().__init__(**kwargs)
83
+ self.finetuning_args = finetuning_args
84
+ if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
85
+ logger.warning("Previous log file in this folder will be deleted.")
86
+ os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
87
+
88
+ def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
89
+ r"""
90
+ Saves trainable parameters as model checkpoint.
91
+
92
+ This function will only be executed at the process zero.
93
+
94
+ Subclass and override to inject custom behavior. It should not be directly used by external scripts.
95
+ """
96
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
97
+ os.makedirs(output_dir, exist_ok=True)
98
+ logger.info(f"Saving model checkpoint to {output_dir}")
99
+ model = unwrap_model(self.model)
100
+
101
+ if hasattr(model, "pretrained_model"): # for models with valuehead
102
+ backbone_model = getattr(model, "pretrained_model")
103
+ else:
104
+ backbone_model = model
105
+
106
+ if hasattr(backbone_model, "peft_config"): # peft methods
107
+ backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
108
+ else:
109
+ torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights
110
+
111
+ if hasattr(model, "v_head"): # save valuehead weights
112
+ torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
113
+
114
+ with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
115
+ f.write(self.args.to_json_string() + "\n")
116
+ self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
117
+
118
+ def _load_best_model(self):
119
+ r"""
120
+ Loads trainable parameters from model checkpoint.
121
+
122
+ Subclass and override to inject custom behavior. It should not be directly used by external scripts.
123
+ """
124
+ logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
125
+ model = unwrap_model(self.model)
126
+ if hasattr(model, "peft_config"): # peft methods
127
+ model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
128
+ else:
129
+ load_trainable_params(model, self.state.best_model_checkpoint)
130
+
131
+ if hasattr(model, "v_head"):
132
+ load_valuehead_params(model, self.state.best_model_checkpoint)
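Each line that LogCallback.on_log appends to trainer_log.jsonl is a self-contained JSON record, so downstream tooling can tail the file; a minimal reader sketch follows (not part of the commit; the commented path is only a placeholder for the actual output_dir).

import json

def read_trainer_log(path):
    # Parse the JSON-lines log written by LogCallback.on_log.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# records = read_trainer_log("path_to_output_dir/trainer_log.jsonl")  # placeholder path
# Each record carries: current_steps, total_steps, loss, reward, learning_rate,
# epoch, percentage, elapsed_time and remaining_time.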
src/utils/ppo.py ADDED
@@ -0,0 +1,223 @@
1
+ import os
2
+ import math
3
+ import torch
4
+ from tqdm import tqdm
5
+ from typing import Callable, Dict, List, Literal, Optional, Tuple
6
+
7
+ from transformers import Seq2SeqTrainingArguments, TrainerState
8
+ from transformers.modeling_utils import PreTrainedModel
9
+
10
+ from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
11
+ from trl.core import LengthSampler
12
+
13
+ from .peft_trainer import PeftTrainer, LogCallback
14
+
15
+ from .config import FinetuningArguments
16
+
17
+ from .other import (
18
+ AverageMeter,
19
+ get_logger,
20
+ get_logits_processor
21
+ )
22
+
23
+
24
+ logger = get_logger(__name__)
25
+
26
+
27
+ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
28
+ if target == "reward": # save the default value head temporarily so it can be restored later
29
+ valuehead_state_dict = model.v_head.state_dict()
30
+
31
+ setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
32
+ setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
33
+
34
+ model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
35
+ model.v_head.load_state_dict({
36
+ "summary.weight": getattr(model, "{}_head_weight".format(target)),
37
+ "summary.bias": getattr(model, "{}_head_bias".format(target))
38
+ })
39
+
40
+
41
+ def cast_layernorm_dtype(
42
+ model: AutoModelForCausalLMWithValueHead,
43
+ layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting
44
+ layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
45
+ ) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
46
+
47
+ layer_norm_state_dict = {}
48
+
49
+ for name, param in model.named_parameters():
50
+ if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
51
+ if layer_norm_params is not None:
52
+ param.data = layer_norm_params[name] # restore float32 weights
53
+ else:
54
+ layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
55
+ param.data = param.data.to(torch.float16)
56
+
57
+ return model, layer_norm_state_dict
58
+
59
+
60
+ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
61
+ r"""
62
+ Inherits PPOTrainer.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ training_args: Seq2SeqTrainingArguments,
68
+ finetuning_args: FinetuningArguments,
69
+ callbacks: List[LogCallback],
70
+ **kwargs
71
+ ):
72
+ PPOTrainer.__init__(self, **kwargs)
73
+ self.args = training_args
74
+ self.finetuning_args = finetuning_args
75
+ self.log_callback = callbacks[0]
76
+ self.state = TrainerState()
77
+ self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
78
+
79
+ def ppo_train(self, max_target_length: int) -> None:
80
+ r"""
81
+ Implements the training loop for the PPO stage, like _inner_training_loop() in Hugging Face's Trainer.
82
+ """
83
+ total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size
84
+ len_dataloader = len(self.dataloader)
85
+ num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
86
+ num_examples = len(self.dataset)
87
+ num_train_epochs = self.args.num_train_epochs
88
+ max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
89
+
90
+ self.state.max_steps = max_steps
91
+ self.state.num_train_epochs = num_train_epochs
92
+ self.state.is_local_process_zero = self.is_local_process_zero()
93
+ self.state.is_world_process_zero = self.is_world_process_zero()
94
+
95
+ if self.is_world_process_zero():
96
+ logger.info("***** Running training *****")
97
+ logger.info(f" Num examples = {num_examples}")
98
+ logger.info(f" Num Epochs = {num_train_epochs}")
99
+ logger.info(f" Instantaneous batch size per device = {self.config.batch_size}")
100
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
101
+ logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}")
102
+ logger.info(f" Total optimization steps = {max_steps}")
103
+ logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
104
+
105
+ # Keyword arguments for `model.generate`
106
+ gen_kwargs = {
107
+ "top_k": 0.0,
108
+ "top_p": 1.0,
109
+ "do_sample": True,
110
+ "pad_token_id": self.tokenizer.pad_token_id,
111
+ "eos_token_id": self.tokenizer.eos_token_id,
112
+ "logits_processor": get_logits_processor()
113
+ }
114
+ output_length_sampler = LengthSampler(max_target_length // 2, max_target_length)
115
+ unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
116
+
117
+ dataiter = iter(self.dataloader)
118
+ steps_trained = 0
119
+ loss_meter = AverageMeter()
120
+ reward_meter = AverageMeter()
121
+
122
+ for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()):
123
+
124
+ for _ in range(self.config.gradient_accumulation_steps):
125
+
126
+ batch = next(dataiter)
127
+ steps_trained += 1
128
+
129
+ unwrapped_model.gradient_checkpointing_disable()
130
+ unwrapped_model.config.use_cache = True
131
+
132
+ # Get response from model
133
+ query_tensors: torch.Tensor = batch["input_ids"]
134
+ response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)
135
+
136
+ queries: List[torch.Tensor] = []
137
+ responses: List[torch.Tensor] = []
138
+ for i in range(len(query_tensors)):
139
+ query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
140
+ response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
141
+ queries.append(query_tensors[i, query_length:]) # remove padding from left
142
+ if response_length < 2: # make response have at least 2 tokens
143
+ responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
144
+ else:
145
+ responses.append(response_tensors[i, :response_length]) # remove padding from right
146
+
147
+ # Compute rewards
148
+ replace_model(unwrapped_model, target="reward")
149
+ _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
150
+ rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
151
+ replace_model(unwrapped_model, target="default") # make sure the model is default at the end
152
+
153
+ # Run PPO step
154
+ unwrapped_model.gradient_checkpointing_enable()
155
+ unwrapped_model.config.use_cache = False
156
+
157
+ stats = self.step(queries, responses, rewards)
158
+
159
+ loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
160
+ reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
161
+
162
+ if steps_trained == len_dataloader:
163
+ dataiter = iter(self.dataloader)
164
+ steps_trained = 0
165
+
166
+ if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
167
+ logs = {
168
+ "loss": round(loss_meter.avg, 4),
169
+ "reward": round(reward_meter.avg, 4),
170
+ "learning_rate": stats["ppo/learning_rate"],
171
+ "epoch": round(step / num_steps_per_epoch, 2)
172
+ }
173
+ print(logs)
174
+ logs["step"] = step
175
+ self.state.log_history.append(logs)
176
+ self.log_callback.on_log(self.args, self.state, None)
177
+ loss_meter.reset()
178
+ reward_meter.reset()
179
+
180
+ if (step+1) % self.args.save_steps == 0: # save checkpoint
181
+ self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
182
+
183
+ @torch.no_grad()
184
+ def generate(
185
+ self,
186
+ inputs: Dict[str, torch.Tensor],
187
+ length_sampler: Optional[Callable] = None,
188
+ return_prompt: Optional[bool] = True,
189
+ **generation_kwargs,
190
+ ) -> torch.Tensor:
191
+ r"""
192
+ Generates model's responses given queries.
193
+
194
+ Subclass and override to inject custom behavior.
195
+ """
196
+ self.model, layer_norm_params = cast_layernorm_dtype(self.model)
197
+
198
+ if length_sampler is not None:
199
+ generation_kwargs["max_new_tokens"] = length_sampler()
200
+
201
+ unwrapped_model = self.accelerator.unwrap_model(self.model)
202
+
203
+ response = unwrapped_model.generate(**inputs, **generation_kwargs)
204
+
205
+ # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
206
+ # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
207
+ if unwrapped_model.pretrained_model.generation_config._from_model_config:
208
+ unwrapped_model.pretrained_model.generation_config._from_model_config = False
209
+
210
+ self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params=layer_norm_params) # restore the float32 layernorm weights
211
+
212
+ if not return_prompt and not self.is_encoder_decoder:
213
+ return response[:, inputs["input_ids"].size(1):]
214
+ return response
215
+
216
+ def save_model(self, output_dir: Optional[str] = None) -> None:
217
+ r"""
218
+ Saves model checkpoint.
219
+
220
+ Subclass and override to inject custom behavior.
221
+ """
222
+ if self.args.should_save:
223
+ self._save(output_dir)
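The step bookkeeping at the top of ppo_train() is worth spelling out: one optimization step consumes gradient_accumulation_steps PPO batches, and max_steps is scaled by the number of epochs. A standalone sketch with hypothetical numbers (not part of the commit):

import math

# Hypothetical values, only to illustrate the arithmetic in ppo_train().
len_dataloader = 100                 # PPO batches yielded per pass over the dataset
gradient_accumulation_steps = 4
num_train_epochs = 3.0

num_steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1)
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
print(num_steps_per_epoch, max_steps)  # 25 75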
src/utils/seq2seq.py ADDED
@@ -0,0 +1,96 @@
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Sequence, Tuple, Union
6
+
7
+ from transformers.trainer import PredictionOutput
8
+ from transformers.tokenization_utils import PreTrainedTokenizer
9
+
10
+ import jieba
11
+ from rouge_chinese import Rouge
12
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
13
+
14
+ from .peft_trainer import PeftTrainer
15
+
16
+ from .other import get_logger, IGNORE_INDEX
17
+
18
+
19
+ logger = get_logger(__name__)
20
+
21
+
22
+ @dataclass
23
+ class ComputeMetrics:
24
+ r"""
25
+ Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
26
+
27
+ Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
28
+ """
29
+
30
+ tokenizer: PreTrainedTokenizer
31
+
32
+ def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
33
+ r"""
34
+ Uses the model predictions to compute metrics.
35
+ """
36
+ preds, labels = eval_preds
37
+ if isinstance(preds, tuple):
38
+ preds = preds[0]
39
+ # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True.
40
+ preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
41
+ labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
42
+
43
+ score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
44
+ for pred, label in zip(preds, labels):
45
+ pred = pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] # remove the query
46
+ hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
47
+ reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
48
+
49
+ if len(" ".join(hypothesis).split()) == 0:
50
+ result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
51
+ else:
52
+ rouge = Rouge()
53
+ scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
54
+ result = scores[0]
55
+
56
+ for k, v in result.items():
57
+ score_dict[k].append(round(v["f"] * 100, 4))
58
+
59
+ bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
60
+ score_dict["bleu-4"].append(round(bleu_score * 100, 4))
61
+
62
+ return {k: float(np.mean(v)) for k, v in score_dict.items()}
63
+
64
+
65
+ class Seq2SeqPeftTrainer(PeftTrainer):
66
+ r"""
67
+ Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
68
+ """
69
+
70
+ def save_predictions(
71
+ self,
72
+ predict_results: PredictionOutput,
73
+ tokenizer: PreTrainedTokenizer
74
+ ) -> None:
75
+ r"""
76
+ Saves model predictions to `output_dir`.
77
+
78
+ A custom behavior that is not contained in Seq2SeqTrainer.
79
+ """
80
+ if not self.is_world_process_zero():
81
+ return
82
+
83
+ preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
84
+ labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
85
+
86
+ preds = [pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] for pred in preds] # remove the queries
87
+ preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
88
+ labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
89
+
90
+ output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
91
+ logger.info(f"Saving prediction results to {output_prediction_file}")
92
+ with open(output_prediction_file, "w", encoding="utf-8") as writer:
93
+ res: List[str] = []
94
+ for pred, label in zip(preds, labels):
95
+ res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
96
+ writer.write("\n".join(res))
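The per-sample metrics in ComputeMetrics can be reproduced on decoded strings directly; below is a simplified standalone sketch (not part of the commit, requires jieba, rouge_chinese and nltk) that scores one hypothesis/reference pair. For clarity it computes BLEU over the jieba tokens, whereas the code above passes the raw sequences to sentence_bleu.

import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

hypothesis = list(jieba.cut("今天天气很好"))
reference = list(jieba.cut("今天天气不错"))

rouge_scores = Rouge().get_scores(" ".join(hypothesis), " ".join(reference))[0]
bleu = sentence_bleu([reference], hypothesis, smoothing_function=SmoothingFunction().method3)

print({k: round(v["f"] * 100, 4) for k, v in rouge_scores.items()})  # rouge-1 / rouge-2 / rouge-l F1
print(round(bleu * 100, 4))                                          # smoothed BLEU-4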
src/utils/template.py ADDED
@@ -0,0 +1,138 @@
1
+ from typing import List, Optional
2
+ from dataclasses import dataclass
3
+
4
+
5
+ @dataclass
6
+ class Template:
7
+
8
+ name: str
9
+
10
+ def __post_init__(self):
11
+
12
+ if self.name == "vanilla":
13
+ r"""
14
+ Supports language model inference without histories.
15
+ """
16
+ self._register_template(
17
+ prefix="",
18
+ prompt="{query}",
19
+ sep="",
20
+ use_history=False
21
+ )
22
+
23
+ elif self.name == "alpaca":
24
+ r"""
25
+ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
26
+ https://github.com/ymcui/Chinese-LLaMA-Alpaca
27
+ """
28
+ self._register_template(
29
+ prefix="Below is an instruction that describes a task. "
30
+ "Write a response that appropriately completes the request.\n\n",
31
+ prompt="### Instruction:\n{query}\n\n### Response:\n",
32
+ sep="\n\n",
33
+ use_history=True
34
+ )
35
+
36
+ elif self.name == "vicuna":
37
+ r"""
38
+ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
39
+ https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
40
+ """
41
+ self._register_template(
42
+ prefix="A chat between a curious user and an artificial intelligence assistant. "
43
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
44
+ prompt="USER: {query} ASSISTANT: ",
45
+ sep="</s>",
46
+ use_history=True
47
+ )
48
+
49
+ elif self.name == "belle":
50
+ r"""
51
+ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
52
+ """
53
+ self._register_template(
54
+ prefix="",
55
+ prompt="Human: {query}\n\nBelle: ",
56
+ sep="\n\n",
57
+ use_history=True
58
+ )
59
+
60
+ elif self.name == "linly":
61
+ r"""
62
+ Supports: https://github.com/CVI-SZU/Linly
63
+ """
64
+ self._register_template(
65
+ prefix="",
66
+ prompt="User: {query}\nBot: ",
67
+ sep="\n",
68
+ use_history=True
69
+ )
70
+
71
+ elif self.name == "billa":
72
+ r"""
73
+ Supports: https://github.com/Neutralzz/BiLLa
74
+ """
75
+ self._register_template(
76
+ prefix="",
77
+ prompt="Human: {query}\nAssistant: ",
78
+ sep="\n",
79
+ use_history=True
80
+ )
81
+
82
+ elif self.name == "ziya":
83
+ r"""
84
+ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
85
+ """
86
+ self._register_template(
87
+ prefix="",
88
+ prompt="<human>:{query}\n<bot>:",
89
+ sep="\n",
90
+ use_history=True
91
+ )
92
+
93
+ elif self.name == "aquila":
94
+ r"""
95
+ Supports: https://huggingface.co/qhduan/aquilachat-7b
96
+ """
97
+ self._register_template(
98
+ prefix="A chat between a curious human and an artificial intelligence assistant. "
99
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
100
+ prompt="Human: {query}\nAssistant: ",
101
+ sep="###",
102
+ use_history=True
103
+ )
104
+
105
+ else:
106
+ raise ValueError("Template {} does not exist.".format(self.name))
107
+
108
+ def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
109
+ r"""
110
+ Returns a string containing prompt without response.
111
+ """
112
+ return "".join(self._format_example(query, history, prefix))
113
+
114
+ def get_dialog(self, query: str, resp: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
115
+ r"""
116
+ Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
117
+ """
118
+ return self._format_example(query, history, prefix) + [resp]
119
+
120
+ def _register_template(self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True) -> None:
121
+ self.prefix = prefix
122
+ self.prompt = prompt
123
+ self.sep = sep
124
+ self.use_history = use_history
125
+
126
+ def _format_example(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
127
+ prefix = prefix if prefix else self.prefix
128
+ history = history if (history and self.use_history) else []
129
+ history = history + [(query, "<dummy>")]
130
+ convs = []
131
+ for turn_idx, (user_query, bot_resp) in enumerate(history):
132
+ if turn_idx == 0:
133
+ convs.append(prefix + self.prompt.format(query=user_query))
134
+ convs.append(bot_resp)
135
+ else:
136
+ convs.append(self.sep + self.prompt.format(query=user_query))
137
+ convs.append(bot_resp)
138
+ return convs[:-1] # drop last
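To show what the template machinery produces, here is a standalone sketch (not part of the commit) that reproduces _format_example for the "alpaca" template with one turn of history; the prefix, prompt and sep strings are copied from the registration above.

prefix = ("Below is an instruction that describes a task. "
          "Write a response that appropriately completes the request.\n\n")
prompt = "### Instruction:\n{query}\n\n### Response:\n"
sep = "\n\n"

history = [("Hi", "Hello!")]
query = "Who are you?"

# get_prompt() joins: prefix + first turn + its response + sep + current query.
expected = prefix + prompt.format(query="Hi") + "Hello!" + sep + prompt.format(query=query)
print(expected)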
src/web_demo.py ADDED
@@ -0,0 +1,150 @@
1
+ # coding=utf-8
2
+ # Implements user interface in browser for fine-tuned models.
3
+ # Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
4
+
5
+
6
+ import mdtex2html
7
+ import gradio as gr
8
+
9
+ from threading import Thread
10
+ from utils import (
11
+ Template,
12
+ load_pretrained,
13
+ prepare_infer_args,
14
+ get_logits_processor
15
+ )
16
+
17
+ from transformers import TextIteratorStreamer
18
+ from transformers.utils.versions import require_version
19
+
20
+
21
+ require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
22
+
23
+
24
+ model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
25
+ model, tokenizer = load_pretrained(model_args, finetuning_args)
26
+
27
+ prompt_template = Template(data_args.prompt_template)
28
+
29
+
30
+ def postprocess(self, y):
31
+ r"""
32
+ Overrides Chatbot.postprocess
33
+ """
34
+ if y is None:
35
+ return []
36
+ for i, (message, response) in enumerate(y):
37
+ y[i] = (
38
+ None if message is None else mdtex2html.convert(message),
39
+ None if response is None else mdtex2html.convert(response),
40
+ )
41
+ return y
42
+
43
+
44
+ gr.Chatbot.postprocess = postprocess
45
+
46
+
47
+ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
48
+ lines = text.split("\n")
49
+ lines = [line for line in lines if line != ""]
50
+ count = 0
51
+ for i, line in enumerate(lines):
52
+ if "```" in line:
53
+ count += 1
54
+ items = line.split("`")
55
+ if count % 2 == 1:
56
+ lines[i] = "<pre><code class=\"language-{}\">".format(items[-1])
57
+ else:
58
+ lines[i] = "<br /></code></pre>"
59
+ else:
60
+ if i > 0:
61
+ if count % 2 == 1:
62
+ line = line.replace("`", "\\`")
63
+ line = line.replace("<", "&lt;")
64
+ line = line.replace(">", "&gt;")
65
+ line = line.replace(" ", "&nbsp;")
66
+ line = line.replace("*", "&ast;")
67
+ line = line.replace("_", "&lowbar;")
68
+ line = line.replace("-", "&#45;")
69
+ line = line.replace(".", "&#46;")
70
+ line = line.replace("!", "&#33;")
71
+ line = line.replace("(", "&#40;")
72
+ line = line.replace(")", "&#41;")
73
+ line = line.replace("$", "&#36;")
74
+ lines[i] = "<br />" + line
75
+ text = "".join(lines)
76
+ return text
77
+
78
+
79
+ def predict(query, chatbot, max_length, top_p, temperature, history):
80
+ chatbot.append((parse_text(query), ""))
81
+
82
+ input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
83
+ input_ids = input_ids.to(model.device)
84
+
85
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
86
+
87
+ gen_kwargs = {
88
+ "input_ids": input_ids,
89
+ "do_sample": generating_args.do_sample,
90
+ "top_p": top_p,
91
+ "temperature": temperature,
92
+ "num_beams": generating_args.num_beams,
93
+ "max_length": max_length,
94
+ "repetition_penalty": generating_args.repetition_penalty,
95
+ "logits_processor": get_logits_processor(),
96
+ "streamer": streamer
97
+ }
98
+
99
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
100
+ thread.start()
101
+
102
+ response = ""
103
+ for new_text in streamer:
104
+ response += new_text
105
+ new_history = history + [(query, response)]
106
+ chatbot[-1] = (parse_text(query), parse_text(response))
107
+ yield chatbot, new_history
108
+
109
+
110
+ def reset_user_input():
111
+ return gr.update(value="")
112
+
113
+
114
+ def reset_state():
115
+ return [], []
116
+
117
+
118
+ with gr.Blocks() as demo:
119
+
120
+ gr.HTML("""
121
+ <h1 align="center">
122
+ <a href="https://chato.cn/" target="_blank">
123
+ 百姓AI助手
124
+ </a>
125
+ </h1>
126
+ """)
127
+
128
+ chatbot = gr.Chatbot()
129
+
130
+ with gr.Row():
131
+ with gr.Column(scale=4):
132
+ with gr.Column(scale=12):
133
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
134
+ with gr.Column(min_width=32, scale=1):
135
+ submitBtn = gr.Button("Submit", variant="primary")
136
+
137
+ with gr.Column(scale=1):
138
+ emptyBtn = gr.Button("Clear History")
139
+ max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
140
+ top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
141
+ temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)
142
+
143
+ history = gr.State([])
144
+
145
+ submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True)
146
+ submitBtn.click(reset_user_input, [], [user_input])
147
+
148
+ emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
149
+
150
+ demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True)
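The predict() generator above relies on running model.generate in a background thread while the callback consumes a TextIteratorStreamer; here is a stripped-down sketch of that pattern outside Gradio (not part of the commit; "gpt2" is only a hypothetical stand-in for the fine-tuned model loaded by load_pretrained).

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "gpt2"  # hypothetical small model, just to demonstrate the streaming pattern
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer(["Hello, my name is"], return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

thread = Thread(target=model.generate, kwargs={**inputs, "max_new_tokens": 20, "streamer": streamer})
thread.start()

response = ""
for new_text in streamer:  # tokens arrive incrementally, as in predict()
    response += new_text
print(response)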