Fine-tuning BLOOM for Summarization with Trainer API

#234
by monta - opened

Objective

Trying to fine tune BLOOM for Summarization using Trainer. Would like to get advice/suggestion if the code below can fine-tune the model as there are not many examples for fine-tuning using Trainer for BLOOM.

Background

The pipeline "summarization" task does not support BLOOM and AutoModel for Seq2Seq does not work as BLOOM is not encoder/decoder model, hence need to come up with a different approach. BLOOMZ uses Megatron but the learning challenge is too high. There is a discussion Fine-tune the model?#46 but conclusion is not clear and it is for QA task.

Question.

Thinking that tokenized prompt as the 'input_ids' and tokenized summary as 'labels' as the training data to the model as below but not sure this is a correct approach or not. Please advise if this works, beyond all, if Trainer is fit for purpose.

DataCollatorWithPadding class does not pad the 'labels' element, which causes an error at train(). Hence used padding at tokenizer to pad labels but not sure this is correct. Please advise if there is another way to manage labels.

Please also give correction/suggestion if any.

Code

import re
from typing import (
    List,
    Dict,
    Callable,
)

import evaluate
import numpy as np
from datasets import (
    load_dataset,
    get_dataset_split_names
)
from promptsource.templates import (
    DatasetTemplates,
    Template
)
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    DataCollatorForSeq2Seq,
    BloomForCausalLM,
    TrainingArguments,
    Trainer
)


# --------------------------------------------------------------------------------
# Huggingface Datasets
# --------------------------------------------------------------------------------
DATASET_STREAMING: bool = False
DATASET_NAME: str = "xsum"
DATASET_TRAIN_NUM_ROWS: int = 204045

train = load_dataset("xsum", split="train", streaming=DATASET_STREAMING)
validation = load_dataset("xsum", split="validation", streaming=DATASET_STREAMING)

# --------------------------------------------------------------------------------
# BLOOM Model
# --------------------------------------------------------------------------------
MODEL = "bigscience/bloom-560m"
MAX_PROMPT_TOKEN_LENGTH: int = 512     # BLOOM token length is 2048
PER_DEVICE_BATCH_SIZE: int = 1

# --------------------------------------------------------------------------------
# PromptSource Template
# --------------------------------------------------------------------------------
prompt_templates = DatasetTemplates( dataset_name=DATASET_NAME)
template: Template = prompt_templates['summarize_DOC']

# --------------------------------------------------------------------------------
# Tokenization
# --------------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)


def get_convert_to_request_response(template: Template) -> Callable:
    def _convert_to_prompt_response(example: Dict[str, str]) -> Dict[str, str]:
        """Generate prompt, response as a dictionary:
        {
            "prompt": "Summarize: ...",
            "response": "..."
        }

        NOTE: DO NOT use with dataset map function( batched=True). Use batch=False

        Args:
            example: single {document, summary} pair to be able to apply template
        Returns: a dictionary of pro
        """
        # assert isinstance(example, dict), f"expected dict but {type(example)}.\n{example}"
        assert isinstance(example['document'], str), f"expected str but {type(example['document'])}."
        prompt, response = template.apply(example=example, truncate=False)
        return {
            "prompt": re.sub(r'[\s\'\"]+', ' ', prompt),
            "response": re.sub(r'[\s\'\"]+', ' ', response)
        }

    return _convert_to_prompt_response


convert_to_request_response: Callable = get_convert_to_request_response(template=template)


def tokenize_prompt_response(examples):
    """Generate the model inputs in the dictionary with format:
    {
        "input_ids": List[int], 
        "attention_mask": List[int]",
        "labels": List[int]
    }
    
    Note: Huggngface dataaset map(batched=True, batch_size=n) merges values of 
    n dictionarys into a values of the key. If you have n instances of {"key", "v"}, then
    you will get {"key": ["v", "v", "v", ...] }.
    
    Args:
        examples:   a dictionary of format {
            "prompt": [prompt+],
            "response": [respnse+]
        } where + means more than one instance because of Dataset.map(batched=True)
    """    
    inputs: Dict[str, List[int]] = tokenizer(
        text_target=examples["prompt"], 
        max_length=MAX_PROMPT_TOKEN_LENGTH, 
        truncation=True
    )

    labels: Dict[str, List[int]] = tokenizer(
        text_target=examples["response"], 
        max_length=MAX_PROMPT_TOKEN_LENGTH, 
        truncation=True,
        padding='max_length',
    )
    inputs["labels"] = labels["input_ids"]
    
    return inputs


prompt_response_train = train.map(
    function=convert_to_request_response, 
    batched=False,
    # batch_size=2048,
    # drop_last_batch=False,
    remove_columns=list(train.features.keys()),
)
tokenized_train = prompt_response_train.map(
    function=tokenize_prompt_response, 
    #batched=True,
    batched=False,
    # batch_size=32,
    # drop_last_batch=True,
    remove_columns=['prompt', 'response']
)
del train, prompt_response_train

prompt_response_validation = validation.map(
    function=convert_to_request_response, 
    batched=False,
    # batch_size=2048,
    # drop_last_batch=False,
    remove_columns=list(validation.features.keys()),
)
tokenized_validation = prompt_response_validation.map(
    function=tokenize_prompt_response, 
    #batched=True,
    batched=False,
    # batch_size=32,
    # drop_last_batch=True,
    remove_columns=['prompt', 'response']
)
del validation, prompt_response_validation

tokenized_train.with_format("torch")
tokenized_validation.with_format("torch")


# --------------------------------------------------------------------------------
# Training
# --------------------------------------------------------------------------------
model = BloomForCausalLM.from_pretrained(MODEL)
model.cuda()


def predict(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors='pt')
    print(inputs["input_ids"].shape)
    
    response_tokens = model.generate(
        inputs["input_ids"].cuda(), 
        max_new_tokens=1,
        do_sample=False, 
        top_k=50, 
        top_p=0.9
    )[0]
    response = tokenizer.decode(response_tokens, skip_special_tokens=True)
    return response


# DataCollatorWithPadding does not pad 'labels' which causes an error at train()
# https://stackoverflow.com/a/74228547/4281353
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer, 
    padding='max_length',
    pad_to_multiple_of=8,
    max_length=MAX_PROMPT_TOKEN_LENGTH,
    return_tensors='pt'
)

rouge = evaluate.load("rouge")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}


training_args = TrainingArguments(
    output_dir="bloom_finetuned",
    max_steps=DATASET_TRAIN_NUM_ROWS * 3 if DATASET_STREAMING else -1,
    num_train_epochs=4,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    learning_rate=2e-5,
    weight_decay=0.01, 
    # fp16=False,
    no_cuda=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    # log_level="debug",
    disable_tqdm=False,
    push_to_hub=False,
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_validation,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)


trainer.train()
trainer.save_model("finetuned_bloom_model")

At least, the training seems working.

image.png

Sign up or log in to comment