lorahub / util.py
from functools import partial
import random

import numpy
import pandas as pd
import torch
import nevergrad as ng
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import Dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, default_data_collator
from peft import PeftModel, PeftConfig
from peft.utils.save_and_load import set_peft_model_state_dict, get_peft_model_state_dict

# fix the random seeds so runs are reproducible
random.seed(42)
numpy.random.seed(42)

def load_base_model_and_lora_modules(lora_module_list):
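    """Load the base model, its tokenizer, and the state dict of every LoRA
    module in `lora_module_list`.

    The first module in the list is used to look up the base model; each
    adapter's state dict is cached so it can be re-weighted later without
    reloading it from the hub.
    """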
    # use gpu if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # load the base model
    default_peft_model_id = lora_module_list[0]
    # find the base model
    model_name_or_path = PeftConfig.from_pretrained(default_peft_model_id).base_model_name_or_path
    base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # 0 is the default model
    peft_model = PeftModel.from_pretrained(base_model, default_peft_model_id)
    peft_model = peft_model.to(device)
    peft_model.eval()

    print("> Begin to load lora modules")
    cache = {}
    for peft_model_id in tqdm(lora_module_list):
        print("> Loading {} ...".format(peft_model_id))
        cur_peft_model = PeftModel.from_pretrained(base_model, peft_model_id)
        cache[peft_model_id] = get_peft_model_state_dict(cur_peft_model)
    return peft_model, tokenizer, cache

def preprocess_function(examples, tokenizer):
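    """Tokenize a batch of input/output examples, masking pad tokens in the
    labels with -100 so they are ignored by the loss."""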
    inputs = examples["input"]
    targets = examples["output"]
    model_inputs = tokenizer(
        inputs,
        max_length=2048,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    labels = tokenizer(
        targets,
        max_length=256,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    labels = labels["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = labels
    return model_inputs

def load_dataset_and_run(example_inputs, example_outputs, tokenizer):
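    """Build a `datasets.Dataset` from the example inputs/outputs and tokenize it."""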
    df = [
        {"input": example_inputs[i], "output": example_outputs[i]}
        for i in range(len(example_inputs))
    ]
    dataset = Dataset.from_pandas(pd.DataFrame(df))
    preprocess_func_with_tokenizer = partial(preprocess_function, tokenizer=tokenizer)
    processed_datasets = dataset.map(
        preprocess_func_with_tokenizer,
        batched=True,
        num_proc=1,
        desc="Running tokenizer on dataset",
    )
    return processed_datasets

def get_score(weights, model, cache, example_dataset):
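    """Objective for the gradient-free optimizer.

    Compose the cached LoRA state dicts with the given weights, load the
    result into the model, and return the loss on the example dataset plus a
    small L1 penalty on the weights.
    """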
    # the composed lora state dict
    final_state_dict = {}
    # the list of lora module ids kept in the cache
    lora_module_list = list(cache.keys())
    # all adapters share the same set of keys
    keys = cache[lora_module_list[0]].keys()
    for i, peft_model_id in enumerate(lora_module_list):
        lora_state_dict = cache[peft_model_id]
        if i == 0:
            for key in keys:
                final_state_dict[key] = weights[i] * lora_state_dict[key]
        else:
            for key in keys:
                final_state_dict[key] = (
                    final_state_dict[key] + weights[i] * lora_state_dict[key]
                )
    # reload the model with the new adapter config
    set_peft_model_state_dict(model, final_state_dict)

    def get_loss():
        train_dataset = example_dataset
        train_dataloader = DataLoader(
            train_dataset,
            collate_fn=default_data_collator,
            batch_size=len(train_dataset),
            pin_memory=True,
        )
        train_loss = 0
        with torch.no_grad():
            # use gpu if available
            device = "cuda" if torch.cuda.is_available() else "cpu"
            for _, batch in enumerate(train_dataloader):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss
                train_loss += loss.detach().float()
        loss = train_loss.float()
        return float(loss) / len(train_dataset["input"])

    # minimize the metric
    loss = get_loss()
    # L1 regularization term on the module weights
    l1_regularization = sum(abs(x) for x in weights) / len(weights)
    metric_val = loss + 0.05 * l1_regularization
    return metric_val

def get_final_weights(weights, lora_module_list, cache):
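    """Merge the cached LoRA state dicts into one state dict using the given
    per-module weights."""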
    final_state_dict = {}
    keys = cache[lora_module_list[0]].keys()
    for i, peft_model_id in enumerate(lora_module_list):
        lora_state_dict = cache[peft_model_id]
        if i == 0:
            for key in keys:
                final_state_dict[key] = weights[i] * lora_state_dict[key]
        else:
            for key in keys:
                final_state_dict[key] = (
                    final_state_dict[key] + weights[i] * lora_state_dict[key]
                )
    return final_state_dict

def lorahub_learning(lora_module_list, text_input, text_output, max_inference_step):
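    """Run LoraHub learning over a set of candidate LoRA modules.

    `text_input` and `text_output` hold the few-shot examples, one example per
    line. A gradient-free optimizer (nevergrad) searches for the per-module
    weights that minimize the loss on these examples; the function returns the
    optimizer's recommendation and the merged LoRA state dict.
    """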
    number_of_loras = len(lora_module_list)
    if number_of_loras == 0:
        return None

    # load model
    model, tokenizer, cache = load_base_model_and_lora_modules(lora_module_list)
    # process dataset
    dataset = load_dataset_and_run(text_input.split("\n"), text_output.split("\n"), tokenizer)

    get_score_partial = partial(get_score, model=model, cache=cache,
                                example_dataset=dataset)
    # set up the limit of the weights
    instrum = ng.p.Array(
        init=[0] * number_of_loras,
        upper=[1.5] * number_of_loras,
        lower=[-1.5] * number_of_loras,
    )
    optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=max_inference_step)
    print("> Begin to perform gradient-free optimization ...")
    recommendation = optimizer.minimize(get_score_partial, verbosity=1)
    final_lora = get_final_weights(recommendation.value, lora_module_list, cache)
    return recommendation, final_lora
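

# --- minimal usage sketch (not part of the original demo code) ---
# The adapter ids below are placeholders for LoRA modules on the Hugging Face
# Hub that share the same seq2seq base model; the few-shot examples and the
# optimization budget are illustrative only.
if __name__ == "__main__":
    candidate_lora_modules = [
        "your-org/flan-t5-lora-task-a",  # placeholder adapter id
        "your-org/flan-t5-lora-task-b",  # placeholder adapter id
    ]
    few_shot_inputs = "Translate to French: Hello\nTranslate to French: Good night"
    few_shot_outputs = "Bonjour\nBonne nuit"
    recommendation, merged_lora = lorahub_learning(
        candidate_lora_modules, few_shot_inputs, few_shot_outputs, max_inference_step=40
    )
    print("learned module weights:", recommendation.value)
    print("merged adapter contains {} tensors".format(len(merged_lora)))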