{ "cells": [
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "qcv24GSIQE5d" }, "outputs": [], "source": [ "from IPython.display import HTML, display\n", "\n", "# register a hook that re-applies the (currently empty) output CSS before every cell runs\n", "def set_css():\n", "    display(HTML('''\n", "    '''))\n", "\n", "get_ipython().events.register('pre_run_cell', set_css)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "SH8dkqPxQtP7" }, "outputs": [], "source": [ "!pip install --upgrade pip\n", "!pip install transformers\n", "!pip install datasets\n", "!pip install sentencepiece" ] },
{ "cell_type": "markdown", "metadata": { "id": "D8hhA8gaQwRR" }, "source": [ "# 📂 Dataset" ] },
{ "cell_type": "markdown", "metadata": { "id": "NF-ouJiDQ1FO" }, "source": [ "### Loading the dataset\n", "---" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "moK3d7mTQ1v-" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "# download the raw CSV and load it with the datasets CSV loader\n", "!wget 'https://raw.githubusercontent.com/jamesesguerra/dataset_repo/main/kami-3000.csv'\n", "\n", "dataset = load_dataset('csv', data_files='kami-3000.csv')\n", "\n", "print(dataset)\n", "print()\n", "print(dataset['train'].features)" ] },
{ "cell_type": "code", "source": [ "# USE THIS CODE BLOCK FOR LOCAL INITIALIZATION (instead of the wget cell above)\n", "\n", "from datasets import load_dataset\n", "\n", "dataset = load_dataset('csv', data_files='C:/Users/Public/Documents/hazielle/kami-3000.csv')\n", "\n", "print(dataset)\n", "print()\n", "print(dataset['train'].features)" ], "metadata": { "id": "NgtZQydpwpB-" }, "execution_count": null, "outputs": [] },
{ "cell_type": "markdown", "metadata": { "id": "zbxmMmtWRCtX" }, "source": [ "### Filtering rows\n", "---" ] },
{ "cell_type": "markdown", "metadata": { "id": "QgoQRt8QREVi" }, "source": [ "**Removing rows with a blank `article_text` or `summary`**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "twzcsfXuRFQQ" }, "outputs": [], "source": [ "dataset = dataset.filter(lambda x: x['article_text'] is not None)\n", "dataset = dataset.filter(lambda x: x['summary'] is not None)\n", "\n", "print(dataset['train'])" ] },
{ "cell_type": "markdown", "metadata": { "id": "30Xl1LGoRKkY" }, "source": [ "**Removing rows whose `article_text` has 25 words or fewer, or whose `summary` has 10 words or fewer**\n", "(thresholds based on [this paper](http://www.diva-portal.org/smash/get/diva2:1563580/FULLTEXT01.pdf))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "6MjsxAZPRLFk" }, "outputs": [], "source": [ "# keep only rows with reasonably long articles and summaries\n", "dataset = dataset.filter(lambda x: len(x['article_text'].split()) > 25)\n", "dataset = dataset.filter(lambda x: len(x['summary'].split()) > 10)\n", "\n", "print(dataset['train'])" ] },
{ "cell_type": "markdown", "metadata": { "id": "YLA2bQeNRPAl" }, "source": [ "### Cleaning\n", "---" ] },
{ "cell_type": "markdown", "metadata": { "id": "z26t9F1URSCO" }, "source": [ "**Unescaping HTML character codes**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "BcUTqeFwRQpC" }, "outputs": [], "source": [ "import html\n", "\n", "# batched map: x['article_text'] is a list of strings here\n", "dataset = dataset.map(\n", "    lambda x: {'article_text': [html.unescape(o) for o in x['article_text']]}, batched=True\n", ")" ] },
{ "cell_type": "markdown", "metadata": { "id": "Y9BFM_A-RVdR" }, "source": [ "**Removing unicode hard spaces (via NFKD normalization)**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "D-MJvuTkRY8c" }, "outputs": [], "source": [ "from unicodedata import normalize\n", "\n", "# NFKD normalization maps non-breaking (hard) spaces to plain spaces\n", "dataset = dataset.map(lambda x: {'article_text': normalize('NFKD', x['article_text'])})" ] },
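{ "cell_type": "markdown", "metadata": {}, "source": [ "**Optional: spot-checking the cleaning**\n", "\n", "An added sanity check (not part of the original pipeline): confirm that a sample article no longer contains HTML entities or non-breaking spaces after the two cleaning steps above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Added sanity check: inspect one cleaned article and confirm that the\n", "# HTML-unescaping and NFKD steps removed entities and hard spaces.\n", "sample = dataset['train'][0]['article_text']\n", "print(sample[:300])\n", "print('still contains &amp;:', '&amp;' in sample)\n", "print('still contains hard space:', chr(160) in sample)" ] },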
{ "cell_type": "markdown", "metadata": { "id": "6th91MJ3RmJW" }, "source": [ "## Dataset splits\n", "---" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "jVJ--r53RoL6" }, "outputs": [], "source": [ "# 80/20 split; the held-out 20% becomes the validation set\n", "dataset = dataset['train'].train_test_split(train_size=0.8, seed=42)\n", "\n", "dataset['validation'] = dataset.pop('test')\n", "\n", "print(dataset)" ] },
{ "cell_type": "markdown", "metadata": { "id": "UFN9ufDYRp9G" }, "source": [ "# 🪙 Tokenization" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "rP1sC2L0R0HB" }, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "# BERT2BERT encoder-decoder checkpoint fine-tuned on CNN/DailyMail\n", "checkpoint = \"patrickvonplaten/bert2bert-cnn_dailymail-fp16\"\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)" ] },
{ "cell_type": "markdown", "metadata": { "id": "1X9Ji15LR8et" }, "source": [ "**Define preprocess function**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "T1L-Q2v8R93o" }, "outputs": [], "source": [ "# set upper limits on how long the articles and their summaries can be\n", "max_input_length = 512\n", "max_target_length = 128\n", "\n", "def preprocess_function(rows):\n", "    model_inputs = tokenizer(rows['article_text'], max_length=max_input_length, truncation=True)\n", "\n", "    # tokenize the summaries in target mode to build the labels\n", "    with tokenizer.as_target_tokenizer():\n", "        labels = tokenizer(rows['summary'], max_length=max_target_length, truncation=True)\n", "\n", "    model_inputs['labels'] = labels['input_ids']\n", "    return model_inputs" ] },
{ "cell_type": "markdown", "metadata": { "id": "JEVi769uSARU" }, "source": [ "**Tokenize the dataset**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "IU5943MESBrK" }, "outputs": [], "source": [ "tokenized_dataset = dataset.map(preprocess_function, batched=True)" ] },
{ "cell_type": "markdown", "metadata": { "id": "o8D04VHjSI6b" }, "source": [ "# 📊 Evaluation Metrics" ] },
{ "cell_type": "markdown", "metadata": { "id": "2GB7-jfKSMrE" }, "source": [ "## ROUGE\n", "---" ] },
{ "cell_type": "markdown", "metadata": { "id": "3TljkwZbSQZV" }, "source": [ "**Installing `rouge_score` and loading the metric**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "HItzZO_mSQG-" }, "outputs": [], "source": [ "!pip install rouge_score" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "7wrZ5kAMSOlH" }, "outputs": [], "source": [ "from datasets import load_metric\n", "rouge_score = load_metric('rouge')" ] },
{ "cell_type": "markdown", "metadata": { "id": "tGOAR4SnSeVY" }, "source": [ "## Creating a lead-3 baseline\n", "---" ] },
{ "cell_type": "markdown", "metadata": { "id": "3OAa8kIfSgC5" }, "source": [ "**Import and download dependencies**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "x8LFH_0qShRO" }, "outputs": [], "source": [ "!pip install nltk\n", "import nltk\n", "\n", "# punkt provides the sentence tokenizer used below\n", "nltk.download(\"punkt\")" ] },
{ "cell_type": "markdown", "metadata": { "id": "WdcIWK8GShzb" }, "source": [ "**Define fn to extract the first 3 sentences of an article**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "17LcLH1FSjtz" }, "outputs": [], "source": [ "from nltk.tokenize import sent_tokenize\n", "\n", "def extract_sentences(text):\n", "    return \"\\n\".join(sent_tokenize(text)[:3])\n", "\n", "print(extract_sentences(dataset[\"train\"][4][\"article_text\"]))" ] },
{ "cell_type": "markdown", "metadata": { "id": "0aHfU4_tSolA" }, "source": [ "**Define fn to build lead-3 summaries and compute ROUGE scores for the baseline**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "08n3A6OGSqK2" }, "outputs": [], "source": [ "def evaluate_baseline(dataset, metric):\n", "    summaries = [extract_sentences(text) for text in dataset[\"article_text\"]]\n", "    return metric.compute(predictions=summaries, references=dataset[\"summary\"])" ] },
{ "cell_type": "markdown", "metadata": { "id": "0fZ67opnSsbe" }, "source": [ "**Use fn to compute ROUGE scores over the validation set**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "nMfTYxxOSwRk" }, "outputs": [], "source": [ "score = evaluate_baseline(dataset[\"validation\"], rouge_score)\n", "\n", "# report the mid F1 score of each ROUGE variant as a percentage\n", "rouge_names = [\"rouge1\", \"rouge2\", \"rougeL\", \"rougeLsum\"]\n", "rouge_dict = dict((rn, round(score[rn].mid.fmeasure * 100, 2)) for rn in rouge_names)\n", "print(rouge_dict)" ] },
{ "cell_type": "markdown", "metadata": { "id": "tyfkBzlxSyA7" }, "source": [ "# 🔩 Fine-tuning" ] },
{ "cell_type": "markdown", "metadata": { "id": "PqlM9-HgS804" }, "source": [ "**Loading the model**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "R1y2goZ3S-CC" }, "outputs": [], "source": [ "from transformers import EncoderDecoderModel\n", "\n", "# pad_token_id=0 matches the BERT tokenizer's [PAD] id\n", "model = EncoderDecoderModel.from_pretrained(checkpoint, pad_token_id=0)\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "MMsjH4Z6TA73" }, "source": [ "**Logging in to the Hugging Face Hub**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "BLSPzmoBTCLk" }, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "notebook_login()" ] },
{ "cell_type": "markdown", "metadata": { "id": "IHH0nuznTD2L" }, "source": [ "**Set up hyperparameters for training**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "CCEGxd76TEff" }, "outputs": [], "source": [ "from transformers import Seq2SeqTrainingArguments\n", "\n", "batch_size = 4\n", "num_train_epochs = 2\n", "# log roughly once per epoch\n", "logging_steps = len(tokenized_dataset['train']) // batch_size\n", "model_name = checkpoint.split('/')[-1]\n", "\n", "args = Seq2SeqTrainingArguments(\n", "    output_dir=f\"{model_name}-finetuned-1.0.0\",\n", "    evaluation_strategy=\"epoch\",\n", "    learning_rate=5e-5,\n", "    per_device_train_batch_size=batch_size,\n", "    per_device_eval_batch_size=batch_size,\n", "    weight_decay=0.01,\n", "    save_total_limit=3,\n", "    num_train_epochs=num_train_epochs,\n", "    predict_with_generate=True,\n", "    logging_steps=logging_steps,\n", "    push_to_hub=True,\n", ")" ] },
{ "cell_type": "markdown", "metadata": { "id": "PdY0ecY9THT8" }, "source": [ "**Define fn to evaluate the model during training**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "d6DqOp4ITKGs" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def compute_metrics(eval_pred):\n", "    predictions, labels = eval_pred\n", "    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n", "    # replace -100 in the labels since they cannot be decoded\n", "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n", "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", "    # ROUGE expects a newline after each sentence\n", "    decoded_preds = [\"\\n\".join(sent_tokenize(pred.strip())) for pred in decoded_preds]\n", "    decoded_labels = [\"\\n\".join(sent_tokenize(label.strip())) for label in decoded_labels]\n", "    result = rouge_score.compute(\n", "        predictions=decoded_preds, references=decoded_labels, use_stemmer=True\n", "    )\n", "    # keep the mid F1 score of each ROUGE variant, as a percentage\n", "    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}\n", "    return {k: round(v, 4) for k, v in result.items()}" ] },
{ "cell_type": "markdown", "metadata": { "id": "y_wEoWIWTMjr" }, "source": [ "**Define data collator for dynamic padding**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ThUqaIr2TPh4" }, "outputs": [], "source": [ "from transformers import DataCollatorForSeq2Seq\n", "\n", "# pads inputs and labels dynamically to the longest sequence in each batch\n", "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)" ] },
{ "cell_type": "markdown", "metadata": { "id": "v_Q4XoW7UaTi" }, "source": [ "**Instantiate trainer with arguments**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "zkCyYVTdUbE7" }, "outputs": [], "source": [ "from transformers import Seq2SeqTrainer\n", "\n", "trainer = Seq2SeqTrainer(\n", "    model,\n", "    args,\n", "    train_dataset=tokenized_dataset[\"train\"],\n", "    eval_dataset=tokenized_dataset[\"validation\"],\n", "    data_collator=data_collator,\n", "    tokenizer=tokenizer,\n", "    compute_metrics=compute_metrics,\n", ")" ] },
{ "cell_type": "markdown", "metadata": { "id": "Ksa_utSpUnO6" }, "source": [ "**Launch training run**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "YBGhf1xYUp7B" }, "outputs": [], "source": [ "trainer.train()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "YCwxQuydUI7K" }, "outputs": [], "source": [ "trainer.evaluate()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "4eNoOqM2rWw1" }, "outputs": [], "source": [ "trainer.push_to_hub()" ] },
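{ "cell_type": "markdown", "metadata": {}, "source": [ "# 🧪 Trying the model\n", "\n", "An added, minimal inference sketch (not part of the original run): generate a summary for one validation article with `model.generate` and compare it with the reference summary. The generation settings (e.g. `num_beams=4`) are illustrative assumptions, not values tuned for this model." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Added inference sketch: summarize one validation article with the fine-tuned model.\n", "# num_beams=4 and the reuse of max_input_length / max_target_length are assumptions.\n", "import torch\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "model = model.to(device)\n", "\n", "article = dataset[\"validation\"][0][\"article_text\"]\n", "inputs = tokenizer(article, max_length=max_input_length, truncation=True, return_tensors=\"pt\").to(device)\n", "\n", "summary_ids = model.generate(\n", "    inputs[\"input_ids\"],\n", "    attention_mask=inputs[\"attention_mask\"],\n", "    max_length=max_target_length,\n", "    num_beams=4,\n", ")\n", "\n", "print(\"Generated:\", tokenizer.decode(summary_ids[0], skip_special_tokens=True))\n", "print(\"Reference:\", dataset[\"validation\"][0][\"summary\"])" ] }
], "metadata": { "accelerator": "GPU", "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }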