"""Train a MiniHF evaluator model (instruction tuned LoRA)."""
|
import argparse
from functools import partial
import os
from pathlib import Path
import sys

os.environ["BITSANDBYTES_NOWELCOME"] = "1"

import accelerate
import datasets
import datasets.distributed
import peft
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils import data
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from tqdm import tqdm
|
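# Wrap print() in tqdm's external write mode so log output doesn't clobber the progress bar.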
print = tqdm.external_write_mode()(print)

|
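# Collate function: right-pads variable-length token id lists into one batch tensor
# and builds the matching attention mask (1 = real token, 0 = padding).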
def batch_to_tensors(batch, device="cpu"):
    batch = [item["input_ids"] for item in batch]
    seq_len = max(len(x) for x in batch)
    input_ids = torch.zeros(len(batch), seq_len, dtype=torch.long, device=device)
    attention_mask = torch.zeros(len(batch), seq_len, dtype=torch.long, device=device)
    for i, x in enumerate(batch):
        input_ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device)
        attention_mask[i, : len(x)] = 1
    return input_ids, attention_mask

|
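# Mean of x weighted by w; used below to average the per-token losses over
# non-padding positions only.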
def weighted_mean(x, w=None, dim=None, keepdim=False, dtype=None):
    w = x.new_tensor(1.0) if w is None else w
    w = w.expand_as(x)
    dim = tuple(range(x.ndim)) if dim is None else dim
    num = torch.sum(x * w, dim=dim, keepdim=keepdim, dtype=dtype)
    denom = torch.sum(w, dim=dim, keepdim=keepdim, dtype=dtype)
    return num / denom

|
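# Cycles a streaming Hugging Face dataset forever, bumping its epoch counter after each
# pass so shuffling differs between passes (note: this relies on the private _epoch
# attribute of datasets.IterableDataset).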
class EndlessHFDataset(data.IterableDataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __iter__(self):
        while True:
            yield from self.dataset
            self.dataset.set_epoch(self.dataset._epoch + 1)

|
def main():
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--batch-size", type=int, default=4, help="batch size per process")
    parser.add_argument("--examples", type=int, default=100000, help="train for n examples")
    parser.add_argument("--output-dir", type=Path, default="evaluator", help="output directory")
    parser.add_argument("--save-every", type=int, default=10000, help="save every n examples")
    args = parser.parse_args()
|
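    # Fixed training hyperparameters (not exposed as command line flags).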
    dataset_seed = 100
    lora_rank = 32
    lr = 1e-4
    max_len = 2048
    model_name = "openlm-research/open_llama_7b"

|
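    # bf16 mixed precision; dispatch_batches=False so each process iterates its own
    # shard of the streaming dataset instead of having the main process dispatch batches.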
    accelerator = accelerate.Accelerator(mixed_precision="bf16", dispatch_batches=False)
    device = accelerator.device
    print0 = accelerator.on_local_main_process(print)

|
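    # The OpenLLaMA tokenizer ships without a pad token, so reuse EOS for padding.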
    print0(f"### Loading tokenizer: {model_name}", file=sys.stderr)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

|
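    # Load the frozen base model quantized to 4-bit NF4 (QLoRA-style); with multiple
    # processes, each rank places the whole model on its own device.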
    print0(f"### Loading model: {model_name}", file=sys.stderr)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    with accelerator.main_process_first():
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto" if accelerator.num_processes == 1 else {"": device},
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
    accelerator.wait_for_everyone()

|
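    # Attach LoRA adapters to every attention and MLP projection plus the LM head;
    # only these adapter weights are trained.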
    print0("### Setting up the LoRA", file=sys.stderr)
    peft_config = peft.LoraConfig(
        task_type=peft.TaskType.CAUSAL_LM,
        inference_mode=False,
        r=lora_rank,
        lora_alpha=8,
        lora_dropout=0.0,
        target_modules=[
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
            "self_attn.o_proj",
            "mlp.gate_proj",
            "mlp.up_proj",
            "mlp.down_proj",
            "lm_head",
        ],
    )
    model = peft.get_peft_model(model, peft_config)
    accelerator.wait_for_everyone()

|
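    # Gradient checkpointing needs inputs that require grad when the base model is
    # frozen, hence enable_input_require_grads().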
    model.train()
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    if accelerator.is_local_main_process:
        model.print_trainable_parameters()

|
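    # Format each row as prompt + "<|end|>" + response, terminated with the tokenizer's
    # EOS token, then tokenize; overlong examples are dropped rather than truncated.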
    def combine_flan(row):
        return row["inputs"] + "<|end|>" + row["targets"] + tokenizer.eos_token

    def combine_dolly(row):
        return (
            row["context"]
            + "\n\n"
            + row["instruction"]
            + "<|end|>"
            + row["response"]
            + tokenizer.eos_token
        )

    def to_tokens(combine_fn, row):
        return tokenizer(combine_fn(row))

    def exclude_too_long(row):
        return len(row["input_ids"]) <= max_len

|
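    # Stream FLAN and Dolly, tokenize on the fly, interleave them 90/10, drop examples
    # longer than max_len, and give each process its own shard of the stream.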
    print0("### Loading datasets", file=sys.stderr)
    with accelerator.main_process_first():
        dataset_1 = datasets.load_dataset("Muennighoff/flan", streaming=True)
        dataset_2 = datasets.load_dataset("databricks/databricks-dolly-15k", streaming=True)
    accelerator.wait_for_everyone()
    dataset_1 = dataset_1["train"].map(partial(to_tokens, combine_flan))
    dataset_2 = dataset_2["train"].map(partial(to_tokens, combine_dolly))
    dataset = (
        datasets.interleave_datasets([dataset_1, dataset_2], probabilities=[0.9, 0.1])
        .filter(exclude_too_long)
        .shuffle(seed=dataset_seed)
        .select_columns(["input_ids"])
    )
    dataset = datasets.distributed.split_dataset_by_node(
        dataset, accelerator.process_index, accelerator.num_processes
    )
    dataloader = data.DataLoader(
        EndlessHFDataset(dataset),
        batch_size=args.batch_size,
        collate_fn=batch_to_tensors,
        drop_last=True,
    )

|
    opt = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99))

|
    model, opt, dataloader = accelerator.prepare(model, opt, dataloader)

|
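    # Dry-run one full-length batch (with a zero-valued backward pass) so an OOM at the
    # maximum sequence length shows up now rather than partway through training.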
    print0("### Testing max sequence length", file=sys.stderr)
    input_ids = torch.zeros([args.batch_size, max_len], dtype=torch.long, device=device)
    attention_mask = torch.ones([args.batch_size, max_len], dtype=torch.long, device=device)
    outputs = model(input_ids, attention_mask=attention_mask, use_cache=False)
    accelerator.backward(outputs.logits.sum() * 0)
    opt.zero_grad()
    torch.cuda.empty_cache()
|
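    # Save the LoRA adapter and tokenizer from the main process only.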
    def save_model():
        print0("### Saving model", file=sys.stderr)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(args.output_dir, safe_serialization=True)
            tokenizer.save_pretrained(args.output_dir)

|
    print0("### Training", file=sys.stderr)
    examples = 0
    last_save = 0
    pbar = tqdm(
        disable=not accelerator.is_local_main_process,
        total=args.examples,
        unit="ex",
        smoothing=0.01,
    )
|
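    # Train until --examples examples have been seen, checkpointing every --save-every
    # examples; Ctrl-C exits cleanly without saving.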
    try:
        for batch in dataloader:
            input_ids, attention_mask = batch
            with accelerator.accumulate(model):
                outputs = model(
                    input_ids[:, :-1],
                    attention_mask=attention_mask[:, :-1],
                    use_cache=False,
                )
                # Next-token prediction: logits at positions 0..T-2 predict tokens 1..T-1.
                # cross_entropy expects the class dimension second, hence the transpose.
                losses = F.cross_entropy(
                    outputs.logits.transpose(-1, -2),
                    input_ids[:, 1:],
                    reduction="none",
                )
                # Only count positions where both the input and the target token are real.
                mask = attention_mask[:, :-1] * attention_mask[:, 1:]
                loss = weighted_mean(losses, mask, dtype=torch.float32)

                accelerator.backward(loss)
                opt.step()
                opt.zero_grad()

            global_batch_size = args.batch_size * accelerator.num_processes
            examples += global_batch_size
            pbar.update(global_batch_size)

            # Average the loss across processes so every rank logs the same number.
            global_loss = accelerator.reduce(loss, "mean")
            print0(f"examples: {examples}, loss: {global_loss.item():g}")

            if examples >= args.examples:
                save_model()
                break

            if examples - last_save >= args.save_every:
                save_model()
                last_save += args.save_every

    except KeyboardInterrupt:
        pass

    finally:
        pbar.close()

|
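# Example launch (the script filename here is illustrative):
#   accelerate launch train_evaluator.py --batch-size 4 --save-every 10000 --output-dir evaluator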
if __name__ == "__main__":
    main()