|
import gc |
|
import os |
|
from dataclasses import dataclass, field |
|
from typing import Dict, List, Optional |
|
|
|
import torch |
|
from datasets import Dataset, DatasetInfo, builder, load_dataset |
|
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments |
|
from vllm import LLM, SamplingParams |
|
from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel |
|
|
|
from trl import DPOTrainer |
|
|
|
|
|
# Monkey-patch the `datasets` library's free-disk-space check to always pass.
# NOTE(review): presumably the probe misreports available space on this
# storage backend (e.g. a network/toolkit filesystem) — confirm before removing.
builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True
|
|
|
|
|
@dataclass
class ScriptArguments:
    """Arguments for generating completion pairs with vLLM and relabelling them with a DPO model."""

    output_dir: Optional[str] = field(
        default="compare_results",
        metadata={"help": "output folder"},
    )
    num_gpus: Optional[int] = field(default=1)
    model_name: Optional[str] = field(default="EleutherAI/pythia-410m", metadata={"help": "the model name"})
    revision: Optional[str] = field(default="main", metadata={"help": "the model revision"})
    tokenizer_name: Optional[str] = field(default=None, metadata={"help": "the tokenizer name"})
    dataset_name: Optional[str] = field(
        default="arianhosseini/openai_summarize_unlabelled", metadata={"help": "the dataset name"}
    )
    dataset_prompt_field: Optional[str] = field(default="query")
    # Fixed help text: previously said "the dataset name" (copy-paste error).
    train_split: Optional[str] = field(default="train[:20]", metadata={"help": "the dataset split"})
    batch_size: Optional[int] = field(default=4)
    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})

    # Sampling hyper-parameters used for vLLM generation.
    temperature: Optional[float] = field(default=0.7, metadata={"help": "Gen temperature"})
    # Fixed help text: previously said "Gen temperature" (copy-paste error).
    top_p: Optional[float] = field(default=1.0, metadata={"help": "Gen top_p"})
    max_new_tokens: Optional[int] = field(default=48, metadata={"help": "max new tokens"})
    dtype: Optional[str] = field(default="auto")

    # Set when `model_name` points to a PEFT/LoRA adapter rather than a full checkpoint.
    lora_model: Optional[bool] = field(default=False)
    # Fixed help text: previously said "the model name" (copy-paste error).
    base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
    base_model_revision: Optional[str] = field(default=None)
|
|
|
|
|
def prepare_vllm_model(script_args):
    """Build a vLLM engine and its tokenizer from ``script_args``.

    When ``script_args.lora_model`` is set, the PEFT adapter is first merged
    into its base model on CPU and the merged weights are saved to disk, and
    vLLM is then pointed at that plain merged checkpoint.

    Returns:
        tuple: ``(llm, tokenizer)`` — the vLLM ``LLM`` engine and the
        configured ``transformers`` tokenizer.
    """
    if script_args.tokenizer_name is not None:
        tokenizer_name = script_args.tokenizer_name
    else:
        # No explicit tokenizer given: use the model repo's own tokenizer.
        tokenizer_name = script_args.model_name

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # EleutherAI (Pythia) tokenizers ship without a pad token; add one.
    if tokenizer_name.startswith("EleutherAI"):
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    # Left-pad so generation continues from the prompt's final token.
    tokenizer.padding_side = "left"

    if script_args.lora_model:
        # Load the adapter on CPU, either onto an explicitly given base model...
        if script_args.base_model_name is not None:
            base_model = AutoModelForCausalLM.from_pretrained(
                script_args.base_model_name, revision=script_args.base_model_revision
            )

            model = PeftModelForCausalLM.from_pretrained(
                base_model, script_args.model_name, revision=script_args.revision, device_map="cpu"
            )
        else:
            # ...or onto the base model recorded in the adapter's own config.
            model = AutoPeftModelForCausalLM.from_pretrained(
                script_args.model_name, revision=script_args.revision, device_map="cpu"
            )
        # Merge the LoRA weights into the base model and persist the result so
        # vLLM can load it as a regular checkpoint.
        merged = model.merge_and_unload()
        # NOTE(review): hard-coded save location — consider making configurable.
        model_save_path = f"/home/toolkit/trl_results/{script_args.model_name}_merged/{script_args.revision}"
        merged.save_pretrained(model_save_path)
        # Release CPU memory before vLLM loads the merged weights.
        del model
        del merged
        model_name = model_save_path
        revision = None
    else:
        model_name = script_args.model_name
        revision = script_args.revision

    llm = LLM(
        model=model_name,
        revision=revision,
        dtype=script_args.dtype,
        tokenizer=tokenizer_name,
        tensor_parallel_size=script_args.num_gpus,
        trust_remote_code=True,
    )
    # Attach the configured tokenizer (pad token + left padding) to the engine.
    llm.set_tokenizer(tokenizer)

    return llm, tokenizer
|
|
|
|
|
def strip_prompt(examples):
    """Strip leading/trailing whitespace from every entry of the "prompt" column.

    Batch-style callback (as used by ``datasets.Dataset.map``): mutates
    *examples* in place and returns it.
    """
    cleaned = []
    for text in examples["prompt"]:
        cleaned.append(text.strip())
    examples["prompt"] = cleaned
    return examples
|
|
|
|
|
def generate_vllm(script_args):
    """Sample two completions per prompt with vLLM and return them as a pairwise dataset.

    The two samples are provisionally labelled ``chosen``/``rejected`` in
    output order; ``relabel`` later reorders each pair by predicted reward.

    Returns:
        datasets.Dataset with columns ``prompt``, ``chosen``, ``rejected``.
    """
    llm, _ = prepare_vllm_model(script_args)

    dataset = load_dataset(script_args.dataset_name, split=script_args.train_split)

    prompts = dataset[script_args.dataset_prompt_field]

    sampling_params = SamplingParams(
        temperature=script_args.temperature,
        max_tokens=script_args.max_new_tokens,
        top_p=script_args.top_p,
        n=2,  # two samples per prompt -> one (chosen, rejected) pair
    )

    generations = llm.generate(prompts, sampling_params)

    print(f"generated {len(generations)} samples")

    def dataset_generator():
        # One row per prompt; skip prompts where vLLM returned fewer than 2 outputs.
        for gen in generations:
            if len(gen.outputs) == 2:
                yield {
                    "prompt": gen.prompt,
                    "chosen": gen.outputs[0].text,
                    "rejected": gen.outputs[1].text,
                }
            else:
                print("skipping gen, only 1 output")

    # Record generation provenance (source split, model, sampling params) in the
    # dataset metadata.
    ds_info = DatasetInfo(
        f"{script_args.dataset_name} split {script_args.train_split} prompts used to generate with {script_args.model_name}"
        f" temp {script_args.temperature} top_p {script_args.top_p} "
    )
    generated_dataset = Dataset.from_generator(dataset_generator, info=ds_info)

    # Tear down vLLM and free GPU memory so the relabelling model can load next.
    # The order matters: destroy parallel state, drop engine references, collect,
    # empty the CUDA cache, then destroy the torch process group.
    destroy_model_parallel()
    del llm.llm_engine.driver_worker
    del llm
    gc.collect()
    torch.cuda.empty_cache()
    torch.distributed.destroy_process_group()

    return generated_dataset
|
|
|
|
|
def relabel(script_args, dataset):
    """Relabel a (prompt, chosen, rejected) dataset using a DPO model's implicit reward.

    Scores both completions of every pair via ``DPOTrainer.predict``, swaps
    ``chosen``/``rejected`` wherever the model prefers the current "rejected"
    completion, and pushes the relabelled dataset to the Hub.
    """
    # Map the dtype string to a torch dtype; "auto"/None pass through unchanged.
    torch_dtype = script_args.dtype if script_args.dtype in ["auto", None] else getattr(torch, script_args.dtype)

    # Load the scoring model: a PEFT adapter over an explicit base model, or an
    # AutoPeft checkpoint that records its base model itself.
    if script_args.base_model_name is not None:
        base_model = AutoModelForCausalLM.from_pretrained(
            script_args.base_model_name,
            revision=script_args.base_model_revision,
        )

        model = PeftModelForCausalLM.from_pretrained(
            base_model,
            script_args.model_name,
            revision=script_args.revision,
            torch_dtype=torch_dtype,
        )
    else:
        model = AutoPeftModelForCausalLM.from_pretrained(
            script_args.model_name,
            revision=script_args.revision,
            torch_dtype=torch_dtype,
        )

    if script_args.tokenizer_name is not None:
        tokenizer_name = script_args.tokenizer_name
    else:
        tokenizer_name = script_args.model_name

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # EleutherAI (Pythia) tokenizers ship without a pad token; add one.
    if tokenizer_name.startswith("EleutherAI"):
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    training_args = TrainingArguments(per_device_eval_batch_size=int(script_args.batch_size), output_dir=".")

    # Trainer is used for prediction only — no train_dataset is supplied.
    dpo_trainer = DPOTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        max_length=script_args.max_new_tokens + script_args.max_prompt_length,
        max_target_length=script_args.max_new_tokens,
        max_prompt_length=script_args.max_prompt_length,
    )

    def relabel_with_preds(batch: Dict[str, List]):
        # Reorder each pair so the higher-reward completion ends up as "chosen".
        relabel_batch = {
            "prompt": [],
            "chosen": [],
            "rejected": [],
            "pred_chosen": [],
            "pred_rejected": [],
        }
        for prompt, chosen, rejected, pred_chosen, pred_rejected in zip(
            batch["prompt"],
            batch["chosen"],
            batch["rejected"],
            batch["pred_chosen"],
            batch["pred_rejected"],
        ):
            relabel_batch["prompt"].append(prompt)
            if pred_chosen >= pred_rejected:
                # Existing order already agrees with the model; keep it.
                relabel_batch["chosen"].append(chosen)
                relabel_batch["rejected"].append(rejected)
                relabel_batch["pred_chosen"].append(pred_chosen)
                relabel_batch["pred_rejected"].append(pred_rejected)
            else:
                # Model prefers the "rejected" completion: swap texts and scores.
                relabel_batch["chosen"].append(rejected)
                relabel_batch["rejected"].append(chosen)
                relabel_batch["pred_chosen"].append(pred_rejected)
                relabel_batch["pred_rejected"].append(pred_chosen)

        return relabel_batch

    dpo_trainer.accelerator.print("Prediction")
    preds, _, metrics = dpo_trainer.predict(dataset)
    # NOTE(review): this unpacking order is tied to what the installed trl
    # version's DPOTrainer returns from predict — re-check on trl upgrades.
    (
        chosen_rewards,
        rejected_rewards,
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
    ) = preds
    dpo_trainer.accelerator.print(f"metrics {metrics}")

    # Only the local main process materialises, relabels, and uploads the data.
    if dpo_trainer.accelerator.is_local_main_process:
        print("Relabelling Dataset and Saving")
        dataset = dataset.add_column("pred_chosen", chosen_rewards)
        dataset = dataset.add_column("pred_rejected", rejected_rewards)

        relabel_dataset = dataset.map(
            relabel_with_preds,
            batched=True,
        )

        description = f"{script_args.dataset_name} relabelled with {script_args.model_name}"
        # NOTE(review): writes the private `_info` attribute of Dataset directly.
        relabel_dataset._info.description = description
        # The Hub repo id is the last path component of output_dir.
        relabel_dataset.push_to_hub(os.path.basename(script_args.output_dir), split="train")
|
|
|
|
|
def generate_relabel_args_dict(args_dict):
    """Run the full generate-then-relabel pipeline from a plain dict of arguments."""
    (script_args,) = HfArgumentParser(ScriptArguments).parse_dict(args_dict)
    generated = generate_vllm(script_args)
    relabel(script_args, generated)
|
|
|
|
|
if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    # NOTE(review): the CLI entry point only generates pairs; it does not call
    # relabel() (unlike generate_relabel_args_dict) — confirm this is intentional.
    generate_vllm(script_args)
|
|