# EasyPrompt / sft.py
# Author: Trace2333 — commit c700ce7 ("initial commit"), 2.82 kB
import time
import evaluate
import numpy as np
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from transformers import TrainingArguments, Trainer
from utils import (
get_dataset,
get_tok_and_model,
get_open_prompt_data,
get_dict_dataset,
get_advance_dataset,)
# Checkpoint name; also reused below to build the training output directory.
base_model = "distilgpt2"

# Load the tokenizer/model pair from the local checkpoint directory.
tokenizer, model = get_tok_and_model("./models/" + base_model)
# Pad with the EOS token so batches can be padded.
# NOTE(review): presumably because the base tokenizer defines no pad token — confirm for other base models.
tokenizer.pad_token = tokenizer.eos_token

# ROUGE scorer used by compute_metrics.
rouge = evaluate.load("rouge")

# Alternative open-prompt loader, kept for reference:
#   train_data, test_data = get_open_prompt_data("./data")
#   train_dataset, test_dataset = get_dataset(train_data, test_data)
dict_data = get_dict_dataset("./data")
# Build the dataset and hold out 20% of it for evaluation.
dataset = get_advance_dataset(dict_data).train_test_split(test_size=0.2)
def preprocess_function(examples, tok=None, max_length=128):
    """Tokenize a batch of ``x`` texts for causal-LM fine-tuning.

    Written for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch mapping with an ``"x"`` column of input strings.
            The ``"y"`` column is not used (see the note below).
        tok: Tokenizer to call; defaults to the module-level ``tokenizer``.
        max_length: Truncation length in tokens (default 128, as before).

    Returns:
        The tokenizer output for ``examples["x"]``, plus a ``"labels"`` entry
        holding a per-row copy of ``input_ids``.
    """
    if tok is None:
        tok = tokenizer
    model_inputs = tok(list(examples["x"]), max_length=max_length, truncation=True)
    # The targets are the input ids themselves (causal-LM style). Each row is
    # copied so that "labels" never aliases the "input_ids" lists.
    # NOTE(review): the original also tokenized examples["y"] (text_target=...)
    # but discarded the result, so "y" never reached training. That dead pass
    # has been removed; confirm whether "y" was meant to be appended to "x".
    model_inputs["labels"] = [list(ids) for ids in model_inputs["input_ids"]]
    return model_inputs
def compute_metrics(eval_pred):
    """Score decoded predictions against decoded labels with ROUGE.

    Args:
        eval_pred: ``(predictions, labels)`` pair. ``predictions`` may be
            token ids (2-D), logits (3-D, last axis = vocabulary), or a tuple
            whose first element is one of those.

    Returns:
        dict of ROUGE scores plus ``"gen_len"`` (mean count of non-pad tokens
        per prediction), each rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    # Some models return extra outputs; the predictions are the first entry.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    # Logits -> greedy token ids.
    # NOTE(review): argmax over causal-LM logits gives next-token predictions,
    # offset by one position from the labels — confirm this is acceptable here.
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)
    # -100 is the ignore/padding marker, not a valid token id; batch_decode
    # fails on it, so map it to the pad id in both predictions and labels.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}
# Alternative collator kept for reference:
#   DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# DataCollatorForSeq2Seq pads each batch, padding the "labels" column
# separately from input_ids.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

print("tokenize data...")
# perf_counter is monotonic, so the measured duration is unaffected by
# wall-clock adjustments (unlike time.time()).
t1 = time.perf_counter()
# Drop the raw text columns; only the tokenizer outputs and labels remain.
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["x", "y"])
t2 = time.perf_counter()
print(f"data tokenize done. process time : {t2 - t1}")
# Hyper-parameters for the fine-tuning run, grouped by concern.
training_args = TrainingArguments(
    # --- Output, checkpointing and logging ---
    # NOTE(review): "openprpmpt" looks like a typo for "openprompt"; left as-is
    # because other tooling may already read from this directory.
    output_dir=f"./output/{base_model}_openprpmpt",
    save_total_limit=1,
    push_to_hub=False,
    report_to="tensorboard",
    # --- Evaluation cadence ---
    # NOTE(review): recent transformers releases deprecate this name in favour
    # of `eval_strategy`; check against the installed version.
    evaluation_strategy="steps",
    eval_steps=20000,
    # --- Optimiser and learning-rate schedule ---
    learning_rate=2e-5,
    lr_scheduler_type="constant",
    adam_beta1=0.9,
    adam_beta2=0.98,
    # --- Batch sizes, duration and precision ---
    per_device_train_batch_size=64,
    per_device_eval_batch_size=32,
    num_train_epochs=100,
    fp16=True,
)
import math

# NOTE(review): compute_metrics is defined in this file but not passed to the
# Trainer, so evaluation reports loss-based metrics only — confirm intent.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()

eval_results = trainer.evaluate()
# Perplexity is exp(eval loss). A very large loss (e.g. a diverged run) makes
# math.exp raise OverflowError; report infinity instead of crashing at the end
# of a long training run.
try:
    perplexity = math.exp(eval_results["eval_loss"])
except OverflowError:
    perplexity = float("inf")
print(f"Perplexity: {perplexity:.2f}")