{ "cells": [ { "cell_type": "markdown", "id": "c08e675e-437e-4e7d-baee-bd55dda74611", "metadata": {}, "source": [ "# Abstractive Text Summarization with T5\n", "\n", "This implementation uses HuggingFace, especially utilizing `AutoModelForSeq2SeqLM` and `AutoTokenizer`. " ] }, { "cell_type": "markdown", "id": "a910e4b5-040d-4499-b5c2-32f3e1ac1c34", "metadata": {}, "source": [ "## Importing libraries" ] }, { "cell_type": "code", "execution_count": 1, "id": "d22ee5a9-1981-4883-a926-db37905ec8b6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setup done!\n" ] } ], "source": [
 "# Installs: %pip installs into the environment of the running kernel,\n",
 "# whereas !pip may pick up a different Python on PATH\n",
 "%pip install -q evaluate py7zr rouge_score absl-py\n",
 "\n",
 "# Imports here (standard library first, then third-party)\n",
 "import os\n",
 "import warnings\n",
 "from pprint import pprint\n",
 "\n",
 "import numpy as np\n",
 "import pandas as pd\n",
 "import matplotlib.pyplot as plt\n",
 "import seaborn as sns\n",
 "import nltk\n",
 "from nltk.tokenize import sent_tokenize\n",
 "\n",
 "import torch\n",
 "import torch.nn as nn\n",
 "\n",
 "import datasets\n",
 "import transformers\n",
 "from transformers import (\n",
 "    AutoModelForSeq2SeqLM,\n",
 "    Seq2SeqTrainingArguments,\n",
 "    Seq2SeqTrainer,\n",
 "    AutoTokenizer\n",
 ")\n",
 "import evaluate\n",
 "from IPython.display import clear_output\n",
 "\n",
 "# Quality of life fixes\n",
 "warnings.filterwarnings('ignore')\n",
 "os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
 "\n",
 "# Seed Python, NumPy and PyTorch so the randomly sampled examples below\n",
 "# are the same after Restart Kernel -> Run All\n",
 "SEED = 42\n",
 "transformers.set_seed(SEED)\n",
 "\n",
 "# Downloads: NLTK sentence tokenizer data and the samsum dataset\n",
 "nltk.download(\"punkt\")\n",
 "samsum = datasets.load_dataset('samsum')\n",
 "\n",
 "# Clear the install/download logs *before* reporting versions; clearing\n",
 "# afterwards would wipe the version lines as well\n",
 "clear_output()\n",
 "print(f\"PyTorch version: {torch.__version__}\")\n",
 "print(f\"Transformers version: {transformers.__version__}\")\n",
 "print(f\"Datasets version: {datasets.__version__}\")\n",
 "print(f\"Evaluate version: {evaluate.__version__}\")\n",
 "print(\"Setup done!\")"
] }, { "cell_type": "code", "execution_count": 2, "id": "bafa753c-0746-4ece-b5eb-4511c9138b09", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'4.27.4'" ] }, "execution_count": 2, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "# Verify transformers version\n", "transformers.__version__" ] }, { "cell_type": "markdown", "id": "f15204cc-0f21-4dc9-a8e4-429c57b227a9", "metadata": {}, "source": [ "## Playing around with the dataset" ] }, { "cell_type": "code", "execution_count": 3, "id": "ba5c1425-a776-4201-97e2-bd420ec112fe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'dialogue', 'summary'],\n", " num_rows: 14732\n", " })\n", " test: Dataset({\n", " features: ['id', 'dialogue', 'summary'],\n", " num_rows: 819\n", " })\n", " validation: Dataset({\n", " features: ['id', 'dialogue', 'summary'],\n", " num_rows: 818\n", " })\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The samsum dataset shape\n", "samsum" ] }, { "cell_type": "code", "execution_count": 4, "id": "5d53736c-a8c7-4fe3-b8f1-566c1d99162b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dialogue:\n", "Ollie: How is your Hebrew?\r\n", "Gabi: Not great. \r\n", "Ollie: Could you translate a letter?\r\n", "Gabi: From Hebrew to English maybe, the opposite I don’t think so\r\n", "Gabi: My writing sucks\r\n", "Ollie: Please help me. I don’t have anyone else to ask\r\n", "Gabi: Send it to me. I’ll try. \n", "\n", " -------------------------------------------------- \n", "\n", "Summary:\n", "Gabi knows a bit of Hebrew, though her writing isn't great. 
She will try to help Ollie translate a letter.\n" ] } ], "source": [
 "# Show one random training example: the dialogue followed by its reference summary\n",
 "rand_idx = np.random.randint(0, len(samsum['train']))\n",
 "train_sample = samsum['train'][rand_idx]  # fetch the row once instead of once per field\n",
 "\n",
 "print(f\"Dialogue:\\n{train_sample['dialogue']}\")\n",
 "print('\\n', '-'*50, '\\n')\n",
 "print(f\"Summary:\\n{train_sample['summary']}\")"
] }, { "cell_type": "markdown", "id": "8f95359e-c9c4-4ed5-9130-5e2b4a0a83ad", "metadata": {}, "source": [ "## Preprocessing data" ] }, { "cell_type": "markdown", "id": "50b572e6-b37a-4688-94c9-9c45a2c67c51", "metadata": {}, "source": [ " I'm using the T5 Transformers model (Text-to-Text Transfer Transformer)" ] }, { "cell_type": "code", "execution_count": 5, "id": "13634dfe-5b1a-4515-9476-8ac0637d0362", "metadata": {}, "outputs": [], "source": [
 "# Checkpoint name, defined once: it is also used as the row label of the results table\n",
 "model_ckpt = 't5-small'\n",
 "\n",
 "# Load the tokenizer from the same checkpoint name instead of a second hard-coded\n",
 "# string, so changing model_ckpt cannot leave the tokenizer behind\n",
 "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
] }, { "cell_type": "code", "execution_count": 6, "id": "6b0be9fc-029b-4057-9d08-29235e5b4573", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at C:\\Users\\QXLVR\\.cache\\huggingface\\datasets\\samsum\\samsum\\0.0.0\\f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e\\cache-78c13bd5dd6a016a.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Max source length: 512\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/15551 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
eval_losseval_rouge1eval_rouge2eval_rougeLeval_rougeLsum
t5-small1.764253100.0100.0100.0100.0
\n", "" ], "text/plain": [ " eval_loss eval_rouge1 eval_rouge2 eval_rougeL eval_rougeLsum\n", "t5-small 1.764253 100.0 100.0 100.0 100.0" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Keep only the headline metrics from `res` and show them as a one-row table\n",
 "# labelled with the checkpoint name.\n",
 "# NOTE(review): the saved output reports exactly 100.0 for all four ROUGE scores\n",
 "# next to an eval_loss of 1.76, which looks implausible -- verify the metric\n",
 "# computation used for evaluation before trusting these numbers.\n",
 "cols = [\"eval_loss\", \"eval_rouge1\", \"eval_rouge2\", \"eval_rougeL\", \"eval_rougeLsum\"]\n",
 "filtered_scores = {col: res[col] for col in cols}\n",
 "pd.DataFrame([filtered_scores], index=[model_ckpt])"
] }, { "cell_type": "code", "execution_count": 20, "id": "7c59a731", "metadata": {}, "outputs": [], "source": [
 "from transformers import pipeline\n",
 "\n",
 "# Use the first GPU when CUDA is available and fall back to CPU (-1) otherwise,\n",
 "# instead of hard-coding device 0 and failing on machines without a GPU\n",
 "pipeline_device = 0 if torch.cuda.is_available() else -1\n",
 "\n",
 "summarizer_pipeline = pipeline(\"summarization\",\n",
 "                               model=model,\n",
 "                               tokenizer=tokenizer,\n",
 "                               device=pipeline_device)"
] }, { "cell_type": "code", "execution_count": 22, "id": "5138f2bc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dialogue: Adelina: Hi handsome. Where you you come from?\r\n", "Cyprien: What do you mean?\r\n", "Adelina: What do you mean, \"what do you mean\"? It's a simple question, where do you come from?\r\n", "Cyprien: Well I was born in Jarrow, live in London now, so you could say I came from either of those places\r\n", "Cyprien: I was educated in Loughborouogh, so in a sense I came from there.\r\n", "Adelina: OK. \r\n", "Cyprien: In another sense I come from my mother's vagina, but I dare say everyone can say that.\r\n", "Adelina: Are you all right?\r\n", "Cyprien: IN another sense I come from the atoms in the air that I breath or the food I eat, which comes to me from many places, so all I can say is \"I come from Planet Earth\".\r\n", "Adelina: OK, bye. If you're gonna be a dick...\r\n", "Cyprien: Wait, what you got against earthlings?\n", "-------------------------\n", "True Summary: Cyprien irritates Adelina by giving too many responses.\n", "-------------------------\n", "Model Summary: Cyprien came from Jarrow, live in London. 
She came from Loughborouogh, and came from her mother's vagina.\n", "-------------------------\n" ] } ], "source": [ "rand_idx = np.random.randint(low=0, high=len(samsum[\"test\"]))\n", "sample = samsum[\"test\"][rand_idx]\n", "\n", "dialog = sample[\"dialogue\"]\n", "true_summary = sample[\"summary\"]\n", "\n", "model_summary = summarizer_pipeline(dialog)\n", "clear_output()\n", "\n", "print(f\"Dialogue: {dialog}\")\n", "print(\"-\"*25)\n", "print(f\"True Summary: {true_summary}\")\n", "print(\"-\"*25)\n", "print(f\"Model Summary: {model_summary[0]['summary_text']}\")\n", "print(\"-\"*25)" ] }, { "cell_type": "code", "execution_count": 24, "id": "f051655f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Your max_length is set to 200, but you input_length is only 94. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=47)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Original Text:\n", "\n", "Andy: I need you to come in to work on the weekend.\n", "David: Why boss? I have plans to go on a concert I might not be able to come on the weekend.\n", "Andy: It's important we need to get our paperwork all sorted out for this year. Corporate needs it.\n", "David: But I already made plans and this is news to me on very short notice.\n", "Andy: Be there or you'r fired\n", "\n", "\n", " -------------------------------------------------- \n", "\n", "Generated Summary: \n", "[{'summary_text': 'David has plans to go on a concert. Andy needs to get his paperwork all sorted out for this year. David already made plans.'}]\n" ] } ], "source": [
 "def create_summary(input_text, model_pipeline=summarizer_pipeline, **generate_kwargs):\n",
 "    \"\"\"Summarize `input_text` by calling a summarization pipeline on it.\n",
 "\n",
 "    Args:\n",
 "        input_text: Text (here, a dialogue) to summarize.\n",
 "        model_pipeline: Callable applied to the text; defaults to the\n",
 "            `summarizer_pipeline` that exists when this cell is run.\n",
 "        **generate_kwargs: Optional keyword arguments forwarded unchanged to\n",
 "            `model_pipeline`, e.g. `max_length` to avoid the length warning\n",
 "            on short inputs.\n",
 "\n",
 "    Returns:\n",
 "        Whatever `model_pipeline` returns. In this notebook's saved output that\n",
 "        is a list holding one dict with a 'summary_text' entry.\n",
 "    \"\"\"\n",
 "    return model_pipeline(input_text, **generate_kwargs)\n",
 "\n",
 "text = '''\n",
 "Andy: I need you to come in to work on the weekend.\n",
 "David: Why boss? I have plans to go on a concert I might not be able to come on the weekend.\n",
 "Andy: It's important we need to get our paperwork all sorted out for this year. Corporate needs it.\n",
 "David: But I already made plans and this is news to me on very short notice.\n",
 "Andy: Be there or you'r fired\n",
 "'''\n",
 "\n",
 "print(f\"Original Text:\\n{text}\")\n",
 "print('\\n', '-'*50, '\\n')\n",
 "\n",
 "summary = create_summary(text)\n",
 "\n",
 "# Print just the generated text, as the random-test-sample cell does, rather\n",
 "# than the raw list of dicts returned by the pipeline\n",
 "print(f\"Generated Summary: \\n{summary[0]['summary_text']}\")"
] }, { "cell_type": "code", "execution_count": null, "id": "ad5d29a0", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.0" } }, "nbformat": 4, "nbformat_minor": 5 }