{ "cells": [ { "cell_type": "markdown", "id": "c08e675e-437e-4e7d-baee-bd55dda74611", "metadata": {}, "source": [ "# Abstractive Text Summarization with T5\n", "\n", "This implementation uses HuggingFace, especially utilizing `AutoModelForSeq2SeqLM` and `AutoTokenizer`. " ] }, { "cell_type": "markdown", "id": "a910e4b5-040d-4499-b5c2-32f3e1ac1c34", "metadata": {}, "source": [ "## Importing libraries" ] }, { "cell_type": "code", "execution_count": 1, "id": "d22ee5a9-1981-4883-a926-db37905ec8b6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setup done!\n" ] } ], "source": [ "# Installs\n", "!pip install -q evaluate py7zr rouge_score absl-py\n", "\n", "# Imports here\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import nltk\n", "from nltk.tokenize import sent_tokenize\n", "nltk.download(\"punkt\")\n", "\n", "import torch\n", "import torch.nn as nn\n", "\n", "import datasets\n", "import transformers\n", "from transformers import (\n", " AutoModelForSeq2SeqLM,\n", " Seq2SeqTrainingArguments,\n", " Seq2SeqTrainer,\n", " AutoTokenizer\n", ")\n", "import evaluate\n", "\n", "# Quality of life fixes\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "from pprint import pprint\n", "\n", "import os\n", "os.environ[\"WANDB_DISABLED\"] = \"true\"\n", "\n", "from IPython.display import clear_output\n", "\n", "print(f\"PyTorch version: {torch.__version__}\")\n", "print(f\"Transformers version: {transformers.__version__}\")\n", "print(f\"Datasets version: {datasets.__version__}\")\n", "print(f\"Evaluate version: {evaluate.__version__}\")\n", "\n", "# Get the samsum dataset\n", "samsum = datasets.load_dataset('samsum')\n", "clear_output()\n", "print(\"Setup done!\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "bafa753c-0746-4ece-b5eb-4511c9138b09", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'4.27.4'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Verify transformers version\n", "transformers.__version__" ] }, { "cell_type": "markdown", "id": "f15204cc-0f21-4dc9-a8e4-429c57b227a9", "metadata": {}, "source": [ "## Playing around with the dataset" ] }, { "cell_type": "code", "execution_count": 3, "id": "ba5c1425-a776-4201-97e2-bd420ec112fe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'dialogue', 'summary'],\n", " num_rows: 14732\n", " })\n", " test: Dataset({\n", " features: ['id', 'dialogue', 'summary'],\n", " num_rows: 819\n", " })\n", " validation: Dataset({\n", " features: ['id', 'dialogue', 'summary'],\n", " num_rows: 818\n", " })\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The samsum dataset shape\n", "samsum" ] }, { "cell_type": "code", "execution_count": 4, "id": "5d53736c-a8c7-4fe3-b8f1-566c1d99162b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dialogue:\n", "Ollie: How is your Hebrew?\r\n", "Gabi: Not great. \r\n", "Ollie: Could you translate a letter?\r\n", "Gabi: From Hebrew to English maybe, the opposite I don’t think so\r\n", "Gabi: My writing sucks\r\n", "Ollie: Please help me. I don’t have anyone else to ask\r\n", "Gabi: Send it to me. I’ll try. \n", "\n", " -------------------------------------------------- \n", "\n", "Summary:\n", "Gabi knows a bit of Hebrew, though her writing isn't great. 
, { "cell_type": "markdown", "id": "8f95359e-c9c4-4ed5-9130-5e2b4a0a83ad", "metadata": {}, "source": [ "## Preprocessing data" ] }, { "cell_type": "markdown", "id": "50b572e6-b37a-4688-94c9-9c45a2c67c51", "metadata": {}, "source": [ "I'm using the T5 model (Text-to-Text Transfer Transformer), which casts summarization as a text-to-text task driven by a task prefix." ] }, { "cell_type": "code", "execution_count": 5, "id": "13634dfe-5b1a-4515-9476-8ac0637d0362", "metadata": {}, "outputs": [], "source": [ "model_ckpt = 't5-small'\n", "\n", "# Create the tokenizer from the pretrained checkpoint\n", "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)" ] }, { "cell_type": "code", "execution_count": 6, "id": "6b0be9fc-029b-4057-9d08-29235e5b4573", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at C:\\Users\\QXLVR\\.cache\\huggingface\\datasets\\samsum\\samsum\\0.0.0\\f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e\\cache-78c13bd5dd6a016a.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Max source length: 512\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map:   0%|          | 0/15551 [00:00<?, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Max target length: 95\n" ] } ], "source": [ "from datasets import concatenate_datasets\n", "\n", "# Find the max lengths of the source and target samples.\n", "# The maximum total input sequence length after tokenization:\n", "# longer sequences will be truncated, shorter ones will be padded.\n", "tokenized_inputs = concatenate_datasets([samsum[\"train\"], samsum[\"test\"]]).map(lambda x: tokenizer(x[\"dialogue\"], truncation=True), batched=True, remove_columns=[\"dialogue\", \"summary\"])\n", "max_source_length = max([len(x) for x in tokenized_inputs[\"input_ids\"]])\n", "print(f\"Max source length: {max_source_length}\")\n", "\n", "# The maximum total sequence length for target text after tokenization:\n", "# longer sequences will be truncated, shorter ones will be padded.\n", "tokenized_targets = concatenate_datasets([samsum[\"train\"], samsum[\"test\"]]).map(lambda x: tokenizer(x[\"summary\"], truncation=True), batched=True, remove_columns=[\"dialogue\", \"summary\"])\n", "max_target_length = max([len(x) for x in tokenized_targets[\"input_ids\"]])\n", "print(f\"Max target length: {max_target_length}\")" ] }
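, { "cell_type": "markdown", "id": "e1f2a3b4-2222-4e22-9a22-aaaaaaaaaaa3", "metadata": {}, "source": [ "T5 is trained with task prefixes, so each dialogue will be framed as `summarize: <dialogue>` before tokenization. A minimal, illustrative round trip through the tokenizer:" ] }, { "cell_type": "code", "execution_count": null, "id": "e1f2a3b4-2222-4e22-9a22-aaaaaaaaaaa4", "metadata": {}, "outputs": [], "source": [ "# Illustrative: frame one dialogue the way T5 will see it and tokenize it\n", "sample_text = 'summarize: ' + samsum['train'][0]['dialogue']\n", "encoded = tokenizer(sample_text, truncation=True, max_length=max_source_length)\n", "print(f\"Input tokens: {len(encoded['input_ids'])}\")\n", "\n", "# Decode the first few token ids back to text to see the prefix\n", "print(tokenizer.decode(encoded['input_ids'][:20]))" ] }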
examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at C:\\Users\\QXLVR\\.cache\\huggingface\\datasets\\samsum\\samsum\\0.0.0\\f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e\\cache-a43b31cabc78c9c3.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Keys of tokenized dataset: ['input_ids', 'attention_mask', 'labels']\n" ] } ], "source": [ "def preprocess_function(\n", " sample, \n", " padding=\"max_length\", \n", " max_source_length=max_source_length,\n", " max_target_length=max_target_length\n", "):\n", " '''\n", " A preprocessing function that will be applied across the dataset.\n", " The inputs and targets will be tokenized and padded/truncated to the max lengths.\n", "\n", " Args:\n", " sample: A dictionary containing the source and target texts (keys are \"dialogue\" and \"summary\") in a list.\n", " padding: Whether to pad the inputs and targets to the max lengths.\n", " max_source_length: The maximum length of the source text.\n", " max_target_length: The maximum length of the target text.\n", " '''\n", " # Add prefix to the input for t5\n", " inputs = ['summarize: ' + s for s in sample['dialogue']]\n", " \n", " # Tokenize inputs, specifying the padding, truncation and max_length\n", " model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)\n", "\n", " # Tokenize targets with the `text_target` keyword argument\n", " labels = tokenizer(text_target=sample['summary'], max_length=max_target_length, padding=padding, truncation=True)\n", "\n", " # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore padding in the loss\n", " if padding == \"max_length\":\n", " labels[\"input_ids\"] = [\n", " [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels[\"input_ids\"]\n", " ]\n", "\n", " # Format and return\n", " model_inputs[\"labels\"] = labels[\"input_ids\"]\n", " return model_inputs\n", "\n", "# Map this preprocessing function to our datasets using .map on the samsum variable\n", "tokenized_dataset = samsum.map(preprocess_function, batched=True, remove_columns=[\"dialogue\", \"summary\", \"id\"])\n", "print(f\"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "3becd236-0097-4ae5-9bd6-a91ed332e748", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['input_ids', 'attention_mask', 'labels'],\n", " num_rows: 14732\n", " })\n", " test: Dataset({\n", " features: ['input_ids', 'attention_mask', 'labels'],\n", " num_rows: 819\n", " })\n", " validation: Dataset({\n", " features: ['input_ids', 'attention_mask', 'labels'],\n", " num_rows: 818\n", " })\n", "})" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenized_dataset" ] }, { "cell_type": "code", "execution_count": 9, "id": "20110839-bb02-4d64-8de7-53253e3f7fe0", "metadata": {}, "outputs": [], "source": [ "metric = evaluate.load(\"rouge\")\n", "clear_output()" ] }, { "cell_type": "code", "execution_count": 10, "id": "ca00f91d-8453-4496-a064-525ef437198f", "metadata": {}, "outputs": [], "source": [ "def postprocess_text(preds, labels):\n", " '''\n", " A simple post-processing function to clean up the predictions and labels\n", "\n", " Args:\n", " preds: List[str] of predictions\n", " labels: List[str] of 
labels\n", " '''\n", " \n", " # strip whitespace on all sentences in preds and labels\n", " preds = [p.strip(' ') for p in preds]\n", " labels = [l.strip(' ') for l in preds]\n", " \n", " # rougeLSum expects newline after each sentence\n", " preds = [\"\\n\".join(sent_tokenize(pred)) for pred in preds]\n", " labels = [\"\\n\".join(sent_tokenize(label)) for label in labels]\n", "\n", " return preds, labels\n", "\n", "def compute_metrics(eval_preds):\n", " \n", " # Fetch the predictions and labels\n", " preds, labels = eval_preds\n", " if isinstance(preds, tuple):\n", " preds = preds[0]\n", " \n", " # Decode the predictions back to text\n", " decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n", " \n", " # Replace -100 in the labels as we can't decode them.\n", " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n", " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", "\n", " # Some simple post-processing for ROUGE\n", " decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)\n", "\n", " # Compute ROUGE on the decoded predictions and the decoder labels\n", " result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)\n", " \n", " result = {k: round(v * 100, 4) for k, v in result.items()}\n", " prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]\n", " result[\"gen_len\"] = np.mean(prediction_lens)\n", " return result" ] }, { "cell_type": "markdown", "id": "7b244846-2ebf-4019-a577-3ef07e350f7c", "metadata": {}, "source": [ "## Creating the model" ] }, { "cell_type": "code", "execution_count": 11, "id": "49c1ac7c-6400-4a67-b32b-5bdc7330d790", "metadata": {}, "outputs": [], "source": [ "# the AutoModelForSeq2SeqLM class and use the model_ckpt variable)\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt)\n", "\n", "clear_output()" ] }, { "cell_type": "code", "execution_count": 12, "id": "e027b290-c04f-4241-b238-41787f32abe0", "metadata": {}, "outputs": [], "source": [ "# we want to ignore tokenizer pad token in the loss\n", "label_pad_token_id = -100\n", "\n", "# Data Collator, specifying the tokenizer, model, and label_pad_token_id\n", "# pad_to_multiple_of=8 to speed up training\n", "data_collator = transformers.DataCollatorForSeq2Seq(\n", " tokenizer,\n", " model=model,\n", " label_pad_token_id=label_pad_token_id,\n", " pad_to_multiple_of=8\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "id": "0d20ee86-ac8c-4ae7-9e7c-92283e879e00", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. 
, { "cell_type": "code", "execution_count": 13, "id": "0d20ee86-ac8c-4ae7-9e7c-92283e879e00", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n" ] } ], "source": [ "import logging\n", "logging.getLogger(\"transformers\").setLevel(logging.WARNING)\n", "\n", "# Define training hyperparameters in Seq2SeqTrainingArguments\n", "training_args = Seq2SeqTrainingArguments(\n", "    output_dir=\"./t5_samsum\",  # the output directory\n", "    logging_strategy=\"epoch\",\n", "    save_strategy=\"epoch\",\n", "    evaluation_strategy=\"epoch\",\n", "    learning_rate=2e-5,\n", "    num_train_epochs=5,\n", "    predict_with_generate=True,\n", "    per_device_train_batch_size=8,\n", "    per_device_eval_batch_size=8,\n", "    weight_decay=0.01,\n", "    load_best_model_at_end=True,\n", "    logging_steps=50,\n", "    logging_first_step=False,\n", "    fp16=False\n", ")\n", "\n", "# Index into the tokenized_dataset variable to get the training and validation data\n", "training_data = tokenized_dataset['train']\n", "eval_data = tokenized_dataset['validation']\n", "\n", "# Create the Trainer for the model\n", "trainer = Seq2SeqTrainer(\n", "    model=model,                      # the model to be trained\n", "    args=training_args,               # training arguments\n", "    train_dataset=training_data,      # the training dataset\n", "    eval_dataset=eval_data,           # the validation dataset\n", "    tokenizer=tokenizer,              # the tokenizer we used to tokenize our data\n", "    compute_metrics=compute_metrics,  # the function we defined above to compute metrics\n", "    data_collator=data_collator       # the data collator we defined above\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "a3b5f21d-b4cb-4f8b-a7fc-cf132ef43c65", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TrainOutput(global_step=9210, training_loss=1.9861197174436753, metrics={'train_runtime': 3551.1547, 'train_samples_per_second': 20.743, 'train_steps_per_second': 2.594, 'total_flos': 9969277096427520.0, 'train_loss': 1.9861197174436753, 'epoch': 5.0})\n" ] } ], "source": [ "# Train the model (this will take a while!)\n", "results = trainer.train()\n", "clear_output()\n", "pprint(results)" ] }, { "cell_type": "markdown", "id": "ddf8c308", "metadata": {}, "source": [ "## Evaluating the model" ] }, { "cell_type": "code", "execution_count": 15, "id": "03e94a7f-2d26-48eb-ab17-cb58b14b93f3", "metadata": {}, "outputs": [], "source": [ "# Evaluate on the validation set\n", "res = trainer.evaluate()\n", "clear_output()" ] }
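, { "cell_type": "markdown", "id": "e1f2a3b4-5555-4e55-9a55-aaaaaaaaaaa9", "metadata": {}, "source": [ "Before looking at the aggregate scores, a quick qualitative check: an illustrative sketch that generates a summary for one validation dialogue with the fine-tuned model (the generation settings here are arbitrary choices)." ] }, { "cell_type": "code", "execution_count": null, "id": "e1f2a3b4-5555-4e55-9a55-aaaaaaaaaab0", "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch: summarize one validation dialogue\n", "sample = samsum['validation'][0]\n", "inputs = tokenizer('summarize: ' + sample['dialogue'], return_tensors='pt',\n", "                   truncation=True, max_length=max_source_length).to(trainer.model.device)\n", "\n", "with torch.no_grad():\n", "    generated = trainer.model.generate(**inputs, max_length=max_target_length, num_beams=4)\n", "\n", "print('Model summary:', tokenizer.decode(generated[0], skip_special_tokens=True))\n", "print('Reference summary:', sample['summary'])" ] }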
\n", " | eval_loss | \n", "eval_rouge1 | \n", "eval_rouge2 | \n", "eval_rougeL | \n", "eval_rougeLsum | \n", "
---|---|---|---|---|---|
t5-small | \n", "1.764253 | \n", "100.0 | \n", "100.0 | \n", "100.0 | \n", "100.0 | \n", "