# Copyright 2022 - Intel Corp. All rights reserved.
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage

"""
Implementation of a new method for fine-tuning transformer models that we call
Information Gain Filtration (IGF) on the WikiText data set, comparing the results
with the standard fine-tuning method.

Steps followed in the code:

1) Generate an objective dataset of pairs (X, IG(X)), where IG(X) is the informativeness
   of context X. Our IG (information gain) model learns to predict the informativeness of
   a particular context. Informativeness is the change in metric between the model's
   performance on an objective set before and after seeing that context. For causal
   language modeling, the metric is perplexity.

2) Train a secondary learner to infer a function approximation for IG using the dataset
   created in (1).

3) Use the learner created in (2) to inform the fine-tuning process and filter out
   samples with low informativeness.

Finally, a plot is generated to compare the performance of IGF to standard fine-tuning
without any filtering.
"""

# Prerequisite libraries:
import argparse
import random

import joblib
import numpy as np
import torch
from igf.igf import (
    SecondaryLearner,
    collect_objective_set,
    compute_perplexity,
    generate_datasets,
    load_gpt2,
    recopy_gpt2,
    set_seed,
    train_secondary_learner,
)
from torch.utils.data import DataLoader, RandomSampler

from transformers import GPT2LMHeadModel


def generate_n_pairs(
    context_len=32,
    max_steps=10,
    size_objective_set=100,
    min_len=1026,
    trim=True,
    data_file="data/tokenized_stories_train_wikitext103.jbl",
    igf_data_file="igf_context_pairs.jbl",
):
    """
    Collecting *n* pairs for training the secondary learner

    Args:
        context_len: The maximum total input sequence length after tokenization. Sequences longer
            than this will be truncated, sequences shorter will be padded
        max_steps: Used to calculate the number of training epochs of the secondary learner
        size_objective_set: Size of the objective data set used to create the (X, IG(X)) pairs,
            which are the training data for the secondary learner
        min_len: The minimum length an article must have to be used in the objective set
        trim: If True, truncate the context when it exceeds context_len
        data_file: Tokenized data set split for training and evaluation of the model
        igf_data_file: File in which to store the (X, IG(X)) paired data set that trains the
            secondary learner

    Returns:
        Data stored in igf_data_file
    """
    # generates the same data every time
    set_seed(3)
    # generate train_data and objective_set
    train_data, objective_set = generate_datasets(
        context_len, data_file, number=size_objective_set, min_len=min_len, trim=trim
    )
    # keeps the model the same across runs
    set_seed(4)
    # check whether we can train on a GPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # load the pretrained model
    model = load_gpt2("gpt2").to(device)
    print("computing perplexity on objective set")
    orig_perp = compute_perplexity(model, objective_set, context_len).item()
    print("perplexity on objective set:", orig_perp)

    # collect (X, IG(X)) pairs and save them to igf_data_file
    collect_objective_set(model, orig_perp, context_len, train_data, objective_set, max_steps, device, igf_data_file)

    # clean up: delete the model and data we don't need anymore
    del model, train_data, objective_set
    torch.cuda.empty_cache()
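

# A minimal sketch of inspecting the pairs collected above. The exact layout of
# the jbl payload is an assumption here (a sequence of (context, IG(context))
# tuples, as described in the docstring), so treat this as illustrative only:
#
#   pairs = joblib.load("igf_context_pairs.jbl")
#   contexts, igf_values = zip(*pairs)
#   print(len(contexts), "pairs; mean IG(X):", float(np.mean(igf_values)))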


def training_secondary_learner(
    secondary_learner_train_data,
    secondary_learner_max_epochs=15,
    secondary_learner_batch_size=128,
    eval_freq=100,
    igf_model_path="igf_model.pt",
):
    """
    Train the secondary learner

    Args:
        secondary_learner_train_data: Data set of (X, IG(X)) pairs used to train the secondary
            learner, where X is a context and IG(X) is its measured informativeness
        secondary_learner_max_epochs: Number of epochs to train the secondary learner
        secondary_learner_batch_size: Batch size used to train the secondary learner
        eval_freq: Evaluation of the secondary model is triggered every eval_freq steps
        igf_model_path: Path at which to store the trained secondary learner

    Returns:
        Trained secondary learner
    """
    set_seed(42)

    # Load the pretrained model
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Initialize the secondary learner with the embedding weights of the model
    secondary_learner = SecondaryLearner(model)

    # Train the secondary learner
    secondary_learner = train_secondary_learner(
        secondary_learner,
        secondary_learner_train_data,
        max_epochs=secondary_learner_max_epochs,
        batch_size=secondary_learner_batch_size,
        eval_freq=eval_freq,
        igf_model_path=igf_model_path,
    )

    del model, secondary_learner_train_data
    torch.cuda.empty_cache()

    return secondary_learner
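

# A minimal sketch of scoring a single context with a trained secondary learner,
# mirroring how finetune() below calls it. The all-zeros context is a stand-in
# for a real tokenized window, so the printed value is illustrative only:
#
#   context = torch.zeros((1, 32), dtype=torch.long)
#   predicted_ig = secondary_learner(context.unsqueeze(0))[0].item()
#   print("predicted IG(X):", predicted_ig)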


def finetune(
    model,
    train_dataset,
    test_dataset,
    context_len=32,
    max_steps=1000,
    batch_size=16,
    threshold=1.0,
    recopy_model=recopy_gpt2,
    secondary_learner=None,
    eval_interval=10,
    finetuned_model_name="gpt2_finetuned.pt",
):
    """
    Fine-tune with IGF if secondary_learner is not None, else run standard fine-tuning

    Args:
        model: pretrained GPT-2 model
        train_dataset: Data set used to train the GPT-2 model
        test_dataset: Data set used to evaluate the GPT-2 model
        context_len: The maximum total input sequence length after tokenization. Sequences longer
            than this will be truncated, sequences shorter will be padded
        max_steps: Used to calculate the number of training epochs
        batch_size: Batch size used to train the GPT-2 model
        threshold: The threshold value used by the secondary learner to filter train_data so that
            only informative data is passed to the model
        recopy_model: Function that resets the model to the original pretrained GPT-2 weights
            after each iteration
        secondary_learner: If not None, fine-tune with IGF using this learner as the filter
        eval_interval: Number of batches between evaluations of the test perplexity
        finetuned_model_name: Name under which the final fine-tuned GPT-2 model is saved

    Returns:
        Fine-tuned GPT-2 model
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler)

    num_train_epochs = max_steps // (len(train_dataset)) + 1
    global_step = 0
    context = torch.zeros((1, context_len), dtype=torch.long, device=device)
    model, lm_optimizer, lm_scheduler = recopy_model(model, device, max_steps)

    model.train()
    if secondary_learner is not None:
        secondary_learner.to(device)
        secondary_learner.eval()

    contexts = []
    examples = 0
    observed_qs = []
    test_perps = []

    # Compute the performance of the transformer model at the beginning
    real_perp = compute_perplexity(model, test_dataset, context_len)
    test_perps.append(real_perp)
    print("Test perplexity, step", global_step, ":", real_perp)

    for epoch in range(int(num_train_epochs)):
        for step, example in enumerate(train_dataloader):
            torch.cuda.empty_cache()
            # Sample a random window of context_len tokens from the example
            start = random.randint(0, example.size(2) - context_len - 1)
            context[0, :] = example[0, 0, start : start + context_len]
            lm_optimizer.zero_grad()
            outputs = model(context, labels=context)
            do_backprop = True

            if secondary_learner is not None:
                # context is already a long tensor on the right device; just add a batch dimension
                predicted_q = secondary_learner.forward(context.unsqueeze(0))[0].item()
                observed_qs.append(float(predicted_q))

                # Here we implement the simple non-constant threshold for the predicted IG(X) value.
                # We decay the selectivity of our secondary learner filter from
                # 1 standard deviation above average to 1 below average after 10 batches.
                if global_step == 10:
                    threshold = -1
                if predicted_q < threshold:
                    do_backprop = False

            # If the context passed the filter, add it to the batch!
            if do_backprop:
                contexts.append(np.array(context.cpu()))
                lm_loss = outputs[0]
                lm_loss.backward()
                examples += 1

            del outputs

            # Once the batch is filled with enough contexts, take an optimizer step.
            if examples == batch_size:
                torch.cuda.empty_cache()
                examples = 0
                # Do LM backprop
                torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
                lm_optimizer.step()
                lm_scheduler.step()  # Update the learning rate schedule
                global_step += 1

                # Compute the performance of the transformer model at this batch
                if global_step % eval_interval == 0:
                    real_perp = compute_perplexity(model, test_dataset, context_len)
                    test_perps.append(real_perp)
                    print("Test perplexity, step", global_step, ":", real_perp)

            # Break out of the loop after 60 batches
            if max_steps > 0 and global_step > 60:
                break
        if max_steps > 0 and global_step > 60:
            break

    # save the fine-tuned transformer model
    torch.save(model.state_dict(), finetuned_model_name)
    torch.cuda.empty_cache()

    # Clean up so we can reinitialize for the next run of this function
    del lm_optimizer
    del lm_scheduler
    return model
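

# A minimal sketch contrasting the two modes of finetune() described in its
# docstring (the object names reuse those created in main() below;
# "baseline_model" is a hypothetical name introduced only for this illustration):
#
#   igf_model = finetune(model, train_dataset, test_dataset,
#                        secondary_learner=secondary_learner)
#   baseline_model = finetune(model, train_dataset, test_dataset,
#                             secondary_learner=None)  # standard fine-tuning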


def main():
    parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain data files for WikiText.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default=None,
        help=(
            "A jbl file containing tokenized data which can be split as objective dataset, "
            "train_dataset and test_dataset."
        ),
    )
    parser.add_argument(
        "--igf_data_file",
        type=str,
        default=None,
        help="A jbl file containing the context and information gain pairs to train the secondary learner.",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the final fine-tuned model is stored.",
    )
    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Pretrained tokenizer name or path, if not the same as model_name",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--context_len",
        default=32,
        type=int,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        ),
    )
    parser.add_argument(
        "--size_objective_set",
        default=100,
        type=int,
        help="Number of articles that are long enough to be used as our objective set",
    )
    parser.add_argument(
        "--eval_freq", default=100, type=int, help="Evaluation of the secondary model is triggered every eval_freq steps"
    )
    parser.add_argument("--max_steps", default=1000, type=int, help="Used to calculate the number of training epochs")
    parser.add_argument(
        "--secondary_learner_batch_size",
        default=128,
        type=int,
        help="Batch size of the training data for the secondary learner",
    )
    parser.add_argument(
        "--batch_size", default=16, type=int, help="Batch size of the training data for the language model (gpt2)"
    )
    parser.add_argument(
        "--eval_interval",
        default=10,
        type=int,
        help="Number of batches between evaluations of the test perplexity during fine-tuning",
    )
    parser.add_argument(
        "--number", default=100, type=int, help="The number of examples split off to be used as objective_set/test_data"
    )
    parser.add_argument(
        "--min_len", default=1026, type=int, help="The minimum length an article must have to be used in the objective set"
    )
    parser.add_argument(
        "--secondary_learner_max_epochs", default=15, type=int, help="Number of epochs to train the secondary learner"
    )
    parser.add_argument("--trim", default=True, type=bool, help="Truncate the example if it exceeds context length")
    parser.add_argument(
        "--threshold",
        default=1.0,
        type=float,
        help=(
            "The threshold value used by the secondary learner to filter train_data so that only "
            "informative data is passed to the model"
        ),
    )
    parser.add_argument(
        "--finetuned_model_name", default="gpt2_finetuned.pt", type=str, help="Name under which the fine-tuned model is saved"
    )
    parser.add_argument(
        "--recopy_model",
        default=recopy_gpt2,
        help=(
            "Function that resets the model to the original pretrained GPT-2 weights after each "
            "iteration (a function default; not meant to be overridden from the command line)"
        ),
    )

    args = parser.parse_args()

    # Collect *n* pairs of contexts and their information gain (X, IG(X)) to train
    # the secondary learner. Pair collection uses a small fixed number of steps (10),
    # independent of --max_steps.
    generate_n_pairs(
        context_len=args.context_len,
        max_steps=10,
        size_objective_set=args.size_objective_set,
        min_len=args.min_len,
        trim=args.trim,
        data_file=args.data_file,
        igf_data_file=args.igf_data_file,
    )

    # Load the (X, IG(X)) training data for the secondary learner
    secondary_learner_train_data = joblib.load("data/IGF_values.jbl")

    # Train the secondary learner
    secondary_learner = training_secondary_learner(
        secondary_learner_train_data,
        secondary_learner_max_epochs=args.secondary_learner_max_epochs,
        secondary_learner_batch_size=args.secondary_learner_batch_size,
        eval_freq=args.eval_freq,
        igf_model_path="igf_model.pt",
    )

    # Load the pretrained GPT-2 model
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    set_seed(42)

    # Generate train and test data sets to fine-tune and evaluate the GPT-2 model
    train_dataset, test_dataset = generate_datasets(
        context_len=args.context_len, file=args.data_file, number=args.number, min_len=args.min_len, trim=args.trim
    )

    # Fine-tune the GPT-2 model using IGF (Information Gain Filtration)
    finetune(
        model,
        train_dataset,
        test_dataset,
        context_len=args.context_len,
        max_steps=args.max_steps,
        batch_size=args.batch_size,
        threshold=args.threshold,
        recopy_model=args.recopy_model,
        secondary_learner=secondary_learner,
        eval_interval=args.eval_interval,
        finetuned_model_name=args.finetuned_model_name,
    )


if __name__ == "__main__":
    main()