# Copyright 2022 - Intel Corp. All rights reserved.
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage

"""
Implementation of a new method for fine-tuning transformer models that we call
Information Gain Filtration 'IGF' on WikiText data set and compared the results
with the standard fine-tuning method

Steps followed in the code:

1) Generate an objective dataset of pairs (X, IG(X)). IG(X)--Informativeness of context 'X'.
Our IG (information gain) model is learning to predict the 'informativeness' of a particular
context. Informativeness is the change in metric between the model's accuracy on an
objective set before and after seeing that context. For causal language modeling, the
metric is perplexity.

2) A secondary learner is trained to infer a function approximation for IG using the dataset
created in (1).

3) The learner created in (2) is used to inform the fine-tuning process and filter out low
informative samples.

Last, a plot is generated to compare the performance of IGF to standard fine-tuning without
any filtering.
NOTE(review): no plotting code is visible in this script; the plot is presumably produced
elsewhere.
"""

# Prerequisite libraries:

import argparse
import random

import joblib
import numpy as np
import torch
from igf.igf import (
    SecondaryLearner,
    collect_objective_set,
    compute_perplexity,
    generate_datasets,
    load_gpt2,
    recopy_gpt2,
    set_seed,
    train_secondary_learner,
)
from torch.utils.data import DataLoader, RandomSampler

from transformers import GPT2LMHeadModel


def generate_n_pairs(
    context_len=32,
    max_steps=10,
    size_objective_set=100,
    min_len=1026,
    trim=True,
    data_file="data/tokenized_stories_train_wikitext103.jbl",
    igf_data_file="igf_context_pairs.jbl",
):
    """
    Collecting *n* pairs for training the secondary learner

    Args:
        context_len: The maximum total input sequence length after tokenization.
                    Sequences longer than this will be truncated, sequences shorter will be padded
        max_steps: To calculate training epochs of secondary learner
        size_objective_set: size of objective data set used to create (X,IG(X)) pairs which is the
                    training data for secondary learner
        min_len: The minimum length of the article to be used as objective set
        trim: If True truncate the context if it exceeds context length
        data_file: Tokenized data set split for training and evaluation of model
        igf_data_file: file to store (X,IG(X)) paired data set to train secondary learner

    Returns:
        None. `igf_data_file` is handed to `collect_objective_set`, which presumably
        writes the collected pairs there (its implementation is not visible here).
    """
    # Fixed seed so the same data split is generated every time.
    set_seed(3)
    # Generate train_data and objective_set. `min_len` and `trim` are forwarded
    # from the caller rather than being hard-coded, so the parameters take effect.
    train_data, objective_set = generate_datasets(
        context_len, data_file, number=size_objective_set, min_len=min_len, trim=trim
    )
    # Fixed seed so the model is the same across runs.
    set_seed(4)

    # Train on GPU when one is available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load pretrained model.
    model = load_gpt2("gpt2").to(device)
    print("computing perplexity on objective set")
    orig_perp = compute_perplexity(model, objective_set, context_len).item()
    print("perplexity on objective set:", orig_perp)

    # Collect IGF pairs; the output file name is passed through as igf_data_file.
    collect_objective_set(model, orig_perp, context_len, train_data, objective_set, max_steps, device, igf_data_file)

    # Clean up: delete the model and data we don't need anymore.
    del model, train_data, objective_set
    torch.cuda.empty_cache()


def training_secondary_learner(
    secondary_learner_train_data,
    secondary_learner_max_epochs=15,
    secondary_learner_batch_size=128,
    eval_freq=100,
    igf_model_path="igf_model.pt",
):
    """
    Train the secondary learner

    Args:
        secondary_learner_train_data: Data set with (X,IG(X)) pairs to train secondary learner
                    where IG(X) - measure of informativeness and X - context
        secondary_learner_max_epochs: Number of epochs to train secondary learner
        secondary_learner_batch_size: Batch size to train secondary learner
        eval_freq: secondary model evaluation can be triggered at eval_freq
        igf_model_path: path to store trained secondary learner

    Returns:
        Trained secondary learner
    """
    set_seed(42)

    # Load pre-trained model.
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Initialize secondary learner to use embedding weights of model.
    secondary_learner = SecondaryLearner(model)

    # Train secondary learner. `eval_freq` is forwarded from the caller rather
    # than being hard-coded to 100, so the parameter takes effect.
    secondary_learner = train_secondary_learner(
        secondary_learner,
        secondary_learner_train_data,
        max_epochs=secondary_learner_max_epochs,
        batch_size=secondary_learner_batch_size,
        eval_freq=eval_freq,
        igf_model_path=igf_model_path,
    )

    del model, secondary_learner_train_data
    torch.cuda.empty_cache()

    return secondary_learner


def finetune(
    model,
    train_dataset,
    test_dataset,
    context_len=32,
    max_steps=1000,
    batch_size=16,
    threshold=1.0,
    recopy_model=recopy_gpt2,
    secondary_learner=None,
    eval_interval=10,
    finetuned_model_name="gpt2_finetuned.pt",
):
    """
    fine-tune with IGF if secondary_learner is not None, else standard fine-tuning

    Args:
        model: pre-trained GPT-2 model
        train_dataset: Data set to train GPT-2 model
        test_dataset: Evaluate GPT-2 model
        context_len: Number of tokens in the window that is sliced at a random offset
                    from each training example and fed to the model
        max_steps: Used to derive the number of epochs and passed to `recopy_model`.
                    Note that training stops once more than 60 optimizer steps have
                    been taken whenever max_steps > 0, regardless of its value.
        batch_size: Number of accepted contexts between two optimizer steps
        threshold: The threshold value used by secondary learner to filter the train_data
                    and allow only informative data as input to the model
        recopy_model: Callable invoked as recopy_model(model, device, max_steps) that
                    returns (model, optimizer, scheduler)
        secondary_learner: Selection of IGF as fine-tuning method if not None
        eval_interval: number of batches after which decay the selectivity of our secondary
                    learner filter from 1 standard deviation above average to 1 below
                    average (threshold becomes -1); also the number of optimizer steps
                    between two test-perplexity evaluations
        finetuned_model_name: name of the final fine-tuned GPT-2 model

    Returns:
        Fine-tuned GPT-2 model
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler)

    num_train_epochs = max_steps // (len(train_dataset)) + 1
    global_step = 0
    # Reusable (1, context_len) buffer that each sampled window is copied into.
    context = torch.zeros((1, context_len), dtype=torch.long, device=device)
    model, lm_optimizer, lm_scheduler = recopy_model(model, device, max_steps)

    model.train()
    if secondary_learner is not None:
        secondary_learner.to(device)
        secondary_learner.eval()
    contexts = []
    examples = 0

    observed_qs = []
    test_perps = []

    # Compute the performance of the transformer model at the beginning.
    real_perp = compute_perplexity(model, test_dataset, context_len)
    test_perps.append(real_perp)
    print("Test perplexity, step", global_step, ":", real_perp)
    for epoch in range(int(num_train_epochs)):
        for step, example in enumerate(train_dataloader):
            torch.cuda.empty_cache()
            # Pick a random window of context_len tokens along the last axis.
            start = random.randint(0, example.size(2) - context_len - 1)
            context[0, :] = example[0, 0, start : start + context_len]
            # NOTE(review): gradients are cleared on every example, so by the time
            # optimizer.step() runs only the most recent accepted example's gradient
            # is present. If accumulation over the whole batch is intended, this
            # call belongs after the optimizer step instead -- confirm before
            # changing, since it alters training results.
            lm_optimizer.zero_grad()
            outputs = model(context, labels=context)
            do_backprop = True

            if secondary_learner is not None:
                # Only the scalar prediction is needed, so skip autograd bookkeeping.
                # detach().clone() copies the buffer without the UserWarning that
                # torch.tensor(<tensor>) emits; dtype and device are already correct.
                with torch.no_grad():
                    predicted_q = secondary_learner.forward(context.detach().clone().unsqueeze(0))[0].item()
                observed_qs.append(float(predicted_q))

                # Here we implement the simple non-constant threshold for the predicted IG(X) value.
                # We decay the selectivity of our secondary learner filter from
                # 1 standard deviation above average to 1 below average after
                # eval_interval batches (10 by default).
                if global_step == eval_interval:
                    threshold = -1
                if predicted_q < threshold:
                    do_backprop = False

            # If we passed the filter, add the context to the batch!
            if do_backprop:
                contexts.append(np.array(context.cpu()))
                lm_loss = outputs[0]
                lm_loss.backward()
                examples += 1

            del outputs

            # Once the batch is filled with enough contexts, backprop on the batch.
            if examples == batch_size:
                torch.cuda.empty_cache()
                examples = 0
                # Do LM backprop
                torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
                lm_optimizer.step()
                lm_scheduler.step()  # Update learning rate schedule
                global_step += 1
                # Compute the performance of the transformer model at this batch.
                if global_step % eval_interval == 0:
                    real_perp = compute_perplexity(model, test_dataset, context_len)
                    test_perps.append(real_perp)

                    print("Test perplexity, step", global_step, ":", real_perp)
            # Break out of the loop after 60 batches.
            if max_steps > 0 and global_step > 60:
                break
        if max_steps > 0 and global_step > 60:
            break

    # Save finetuned transformer model.
    torch.save(model.state_dict(), finetuned_model_name)
    torch.cuda.empty_cache()
    # Do some cleaning up so we can reinitialize for the next run of this function.
    del lm_optimizer
    del lm_scheduler
    return model


def main():
    """
    Run the full IGF pipeline: collect (X, IG(X)) pairs, train the secondary
    learner, then fine-tune GPT-2 with the secondary learner as a filter.

    NOTE(review): the argument parser below is built but `parse_args()` is never
    called, so every command-line flag is currently ignored and the pipeline runs
    with the literal values written in the calls further down. Wiring the parser
    in would make `--data_dir`, `--model_name_or_path` and `--output_dir`
    mandatory, so it is left as-is pending a decision on the intended CLI.
    """
    parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain data files for WikiText.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default=None,
        help=(
            "A jbl file containing tokenized data which can be split as objective dataset, "
            "train_dataset and test_dataset."
        ),
    )

    parser.add_argument(
        "--igf_data_file",
        type=str,
        default=None,
        help="A jbl file containing the context and information gain pairs to train secondary learner.",
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the final fine-tuned model is stored.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")

    parser.add_argument(
        "--context_len",
        default=32,
        type=int,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        ),
    )

    parser.add_argument(
        "--size_objective_set",
        default=100,
        type=int,
        help="number of articles that are long enough to be used as our objective set",
    )
    parser.add_argument(
        "--eval_freq", default=100, type=int, help="secondary model evaluation is triggered at eval_freq"
    )

    parser.add_argument("--max_steps", default=1000, type=int, help="To calculate training epochs")

    parser.add_argument(
        "--secondary_learner_batch_size",
        default=128,
        type=int,
        help="batch size of training data for secondary learner",
    )

    parser.add_argument(
        "--batch_size", default=16, type=int, help="batch size of training data of language model(gpt2) "
    )

    parser.add_argument(
        "--eval_interval",
        default=10,
        type=int,
        help=(
            "decay the selectivity of our secondary learner filter from "
            "1 standard deviation above average to 1 below average after 10 batches"
        ),
    )

    parser.add_argument(
        "--number", default=100, type=int, help="The number of examples split to be used as objective_set/test_data"
    )

    parser.add_argument(
        "--min_len", default=1026, type=int, help="The minimum length of the article to be used as objective set"
    )

    parser.add_argument(
        "--secondary_learner_max_epochs", default=15, type=int, help="number of epochs to train secondary learner"
    )

    # NOTE(review): type=bool converts any non-empty string (including "False")
    # to True; this needs a real string-to-bool converter once args are parsed.
    parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")

    parser.add_argument(
        "--threshold",
        default=1.0,
        type=float,
        help=(
            "The threshold value used by secondary learner to filter the train_data and allow only"
            " informative data as input to the model"
        ),
    )

    parser.add_argument("--finetuned_model_name", default="gpt2_finetuned.pt", type=str, help="finetuned_model_name")

    # NOTE(review): the default is a function object while type=str; a value
    # given on the command line would arrive as a plain string, not a callable.
    parser.add_argument(
        "--recopy_model",
        default=recopy_gpt2,
        type=str,
        help="Reset the model to the original pretrained GPT-2 weights after each iteration",
    )

    # function calls
    # Collecting *n* pairs of context and information gain(X, IG(X)) for training the secondary learner
    generate_n_pairs(
        context_len=32,
        max_steps=10,
        size_objective_set=100,
        min_len=1026,
        trim=True,
        data_file="data/tokenized_stories_train_wikitext103.jbl",
        igf_data_file="igf_context_pairs.jbl",
    )

    # Load train data for secondary learner.
    # NOTE(review): this path differs from the igf_data_file handed to
    # generate_n_pairs above ("igf_context_pairs.jbl") -- confirm which file is
    # meant to feed the secondary learner.
    secondary_learner_train_data = joblib.load("data/IGF_values.jbl")

    # Train secondary learner
    secondary_learner = training_secondary_learner(
        secondary_learner_train_data,
        secondary_learner_max_epochs=15,
        secondary_learner_batch_size=128,
        eval_freq=100,
        igf_model_path="igf_model.pt",
    )

    # load pretrained gpt2 model
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    set_seed(42)

    # Generate train and test data to train and evaluate gpt2 model
    train_dataset, test_dataset = generate_datasets(
        context_len=32, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
    )

    # fine-tuning of the gpt2 model using igf (Information Gain Filtration)
    finetune(
        model,
        train_dataset,
        test_dataset,
        context_len=32,
        max_steps=1000,
        batch_size=16,
        threshold=1.0,
        recopy_model=recopy_gpt2,
        secondary_learner=secondary_learner,
        eval_interval=10,
        finetuned_model_name="gpt2_finetuned.pt",
    )


if __name__ == "__main__":
    main()