pythia410m-sft-tldr / code /generate_and_eval.py
mnoukhov's picture
Training in progress, step 500
1904ee8 verified
import gc
import os
from dataclasses import dataclass, field
from typing import List, Optional
import torch
from datasets import Dataset, builder, load_dataset
from huggingface_hub import list_repo_refs
from peft import PeftModelForCausalLM
from scalar_rm_model import ScalarModel, ScalarModelConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, Trainer, TrainingArguments
from vllm import LLM, SamplingParams
from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
import wandb
builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True
@dataclass
class GenerateScriptArguments:
output_dir: Optional[str] = field(
default="/home/toolkit/trl_results",
metadata={"help": "output folder"},
)
num_gpus: Optional[int] = field(default=1)
base_model_name: Optional[str] = field(default=None, metadata={"help": "the model name"})
base_model_revision: Optional[str] = field(default=None)
model_name: Optional[str] = field(default="EleutherAI/pythia-410m", metadata={"help": "the model name"})
model_revisions: Optional[List[str]] = field(default_factory=list)
# base_model_revision: Optional[str] = field(default=None)
tokenizer_name: Optional[str] = field(default=None, metadata={"help": "the tokenizer name"})
dataset_name: Optional[str] = field(
default="arianhosseini/openai_summarize_unlabelled", metadata={"help": "the dataset name"}
)
split: Optional[str] = field(default="validation", metadata={"help": "the dataset name"})
batch_size: Optional[int] = field(default=4)
seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
temperature: Optional[float] = field(default=0.7, metadata={"help": "Gen temperature"})
top_p: Optional[float] = field(default=1.0, metadata={"help": "Gen temperature"})
max_new_tokens: Optional[int] = field(default=48, metadata={"help": "max new tokens"})
gen_dtype: Optional[str] = field(default="auto")
@dataclass
class EvalScriptArguments:
wandb_log_id: Optional[str] = field(default=None)
gold_model_name: Optional[str] = field(default="EleutherAI/pythia-410m", metadata={"help": "the model name"})
gold_model_revision: Optional[str] = field(default=None)
eval_dtype: Optional[str] = field(default="auto")
eval_batch_size: Optional[int] = field(default=16)
max_length: Optional[int] = field(default=512)
gold_tokenizer_name: Optional[str] = field(default=None, metadata={"help": "the tokenizer name"})
flash_attention: Optional[bool] = field(default=False)
def generate(script_args):
tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.padding_side = "left"
dataset = load_dataset(script_args.dataset_name, split=script_args.split)
prompts = dataset["query"]
sampling_params = SamplingParams(
temperature=script_args.temperature,
max_tokens=script_args.max_new_tokens,
top_p=script_args.top_p,
n=1,
include_stop_str_in_output=True,
skip_special_tokens=False,
)
refs = list_repo_refs(script_args.model_name, repo_type="model")
gens = {}
revisions = sorted([branch.name for branch in refs.branches])
for revision in revisions:
if revision == "main":
continue
if script_args.model_revisions and revision not in script_args.model_revisions:
continue
print(f"generating step {revision}")
if script_args.base_model_name is None:
# merged model
model_name = script_args.model_name
revision_name = revision
else:
# peft model that needs to be merged
base_model = AutoModelForCausalLM.from_pretrained(
script_args.base_model_name, revision=script_args.base_model_revision
)
# merge the model and save
model = PeftModelForCausalLM.from_pretrained(
base_model, script_args.model_name, revision=revision, device="cpu"
)
merged = model.merge_and_unload()
model_save_path = f"/home/toolkit/trl_results/{script_args.model_name}_merged/{revision}"
merged.save_pretrained(model_save_path)
del model
del merged
model_name = model_save_path
revision_name = revision
revision = None
llm = LLM(
model=model_name,
revision=revision,
tokenizer=script_args.tokenizer_name,
dtype=script_args.gen_dtype,
max_model_len=script_args.seq_length,
tensor_parallel_size=script_args.num_gpus,
trust_remote_code=True,
)
llm.set_tokenizer(tokenizer)
generations = llm.generate(prompts, sampling_params)
texts = [output.prompt + output.outputs[0].text for output in generations]
gens[revision_name] = texts
dataset = dataset.add_column(f"generations_{revision_name}", texts)
# delete old model
destroy_model_parallel()
del llm.llm_engine.driver_worker
del llm
gc.collect()
torch.cuda.empty_cache()
torch.distributed.destroy_process_group()
if script_args.output_dir is not None:
# TODO add hash to dataset path
# sampling_str = str(sampling_params)
# sampling_hash = hashlib.sha256(sampling_str.encode()).hexdigest()[:10]
dataset_path = os.path.join(
script_args.output_dir,
script_args.dataset_name.replace("/", "_"),
script_args.model_name.replace("/", "_"),
)
os.makedirs(dataset_path, exist_ok=True)
dataset.save_to_disk(dataset_path)
with open(f"{dataset_path}_sampling_params.txt", "w") as f:
print(sampling_params, file=f)
print(f"generated {len(gens)} steps")
reference = dataset["query_reference_response"]
return reference, gens
# ds_info = DatasetInfo(
# f"{script_args.dataset_name} split {script_args.train_split} prompts used to generate with {script_args.model_name}"
# f" temp {script_args.temperature} top_p {script_args.top_p} "
# )
# generated_dataset = Dataset.from_generator(dataset_generator, info=ds_info)
# generated_dataset.push_to_hub(os.path.basename(script_args.output_dir), split="train")
def evaluate(args, reference, generations, model_name=None):
if args.wandb_log_id is not None:
# don't overwrite the wandb name of the original run
if args.wandb_log_id == "model_name":
# model name = config_wandblogid
wandb_log_id = model_name.split("_")[-1]
else:
wandb_log_id = args.wandb_log_id
os.environ.pop("WANDB_NAME")
# original_name = wandb_name.removeprefix("geneval_")
wandb.init(id=wandb_log_id, resume="allow")
log_to_wandb = True
print(f"Logging to WandB {wandb_log_id}")
else:
log_to_wandb = False
torch_dtype = args.eval_dtype if args.eval_dtype in ["auto", None] else getattr(torch, args.eval_dtype)
tokenizer = AutoTokenizer.from_pretrained(args.gold_tokenizer_name)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
scalar_model_config = ScalarModelConfig.from_pretrained(
args.gold_model_name,
revision=args.gold_model_revision,
)
# hack to remove the path
# models/EleutherAI/pythia-6.9b-deduped/sft_model_55513 -> EleutherAI/pythia-6.9b-deduped
if scalar_model_config.base_model.startswith("models/"):
original_model = scalar_model_config.base_config["_name_or_path"].split("/")[2]
sft_model = f"vwxyzjn/EleutherAI_{original_model}__sft__tldr"
scalar_model_config.base_config["_name_or_path"] = sft_model
scalar_model_config.base_model = sft_model
_, seed, _ = args.gold_model_revision.split("__")
scalar_model_config.base_model_revision = f"sft__{seed}__1708611267"
# quantization_config = get_quantization_config(model_config)
model = ScalarModel.from_pretrained(
args.gold_model_name,
revision=args.gold_model_revision,
config=scalar_model_config,
torch_dtype=torch_dtype,
use_flash_attention_2=args.flash_attention,
)
model.config.pad_token_id = tokenizer.pad_token_id
training_args = TrainingArguments(per_device_eval_batch_size=int(args.eval_batch_size), output_dir=".")
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
)
def tokenize_and_add_eos(tokenizer, text_column, max_length):
def fn(example):
text = example[text_column]
if not text.endswith(tokenizer.eos_token):
text += tokenizer.eos_token
tokenized = tokenizer(
text,
padding="max_length",
max_length=max_length,
truncation=True,
)
# guarantee that last token is EOS if truncated
token_length = sum(tokenized["attention_mask"])
if token_length == max_length:
tokenized["input_ids"][-1] = tokenizer.eos_token_id
return tokenized
return fn
## get reference continuation rewards
dataset = Dataset.from_dict({"reference": reference})
dataset = dataset.map(tokenize_and_add_eos(tokenizer, "reference", args.max_length))
ref_results = trainer.predict(dataset)
ref_rewards = ref_results.predictions
step = 0
for step_str, query_response in generations.items():
dataset = Dataset.from_dict({"query_response": query_response})
dataset = dataset.map(tokenize_and_add_eos(tokenizer, "query_response", args.max_length))
print(f"Evaluating {step_str}")
results = trainer.predict(dataset)
gen_rewards = results.predictions
win_rate = (gen_rewards > ref_rewards).mean().item()
norm_reward = (gen_rewards - ref_rewards).mean().item()
if step_str.startswith("step"):
step_str = step_str.removeprefix("step")
if step_str.isdigit():
step = int(step_str)
else:
print(f"Warning step name {step_str} is not an integer")
step = step + 1
if log_to_wandb:
wandb.log(
{
"gold/win_rate": win_rate,
"gold/norm_reward": norm_reward,
"train/global_step": step,
}
)
print(f"step {step}: win-rate {win_rate} norm-reward {norm_reward}")
def main_args_dict(args_dict):
parser = HfArgumentParser([GenerateScriptArguments, EvalScriptArguments])
generate_args, eval_args = parser.parse_dict(args_dict)
if eval_args.gold_tokenizer_name is None:
eval_args.gold_tokenizer_name = generate_args.tokenizer_name
print("GENERATING")
reference, generations = generate(generate_args)
# dataset = load_dataset(generate_args.dataset_name, split=generate_args.split)
# generations = {"step0": dataset["query_reference_response"]}
# reference = dataset["query_reference_response"]
print("EVALUATING")
evaluate(eval_args, reference, generations, generate_args.model_name)
if __name__ == "__main__":
parser = HfArgumentParser([GenerateScriptArguments, EvalScriptArguments])
generate_args, eval_args = parser.parse_args_into_dataclasses()
if eval_args.gold_tokenizer_name is None:
eval_args.gold_tokenizer_name = generate_args.tokenizer_name
print("GENERATING")
reference, generations = generate(generate_args)
# dataset = load_dataset(generate_args.dataset_name, split=generate_args.train_split)
# generations = {"step0": dataset["query_reference_response"]}
# reference = dataset["query_reference_response"]
print("EVALUATING")
evaluate(eval_args, reference, generations)