{ "cells": [ { "cell_type": "raw", "id": "42dcf697", "metadata": {}, "source": [ "---\n", "title: 18 Neural Machine Translation using Transformer\n", "description: An implementation of Transformer to translate human readable dates in any format to YYYY-MM-DD format.\n", "---" ] }, { "cell_type": "markdown", "id": "df9402ac", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "-EKH_E-Z2bBp", "metadata": { "id": "-EKH_E-Z2bBp" }, "source": [ "" ] }, { "cell_type": "markdown", "id": "eea58ad9", "metadata": { "id": "eea58ad9", "papermill": { "duration": 0.01953, "end_time": "2022-04-19T17:09:39.123240", "exception": false, "start_time": "2022-04-19T17:09:39.103710", "status": "completed" }, "tags": [] }, "source": [ "# Neural Machine Translation\n", "\n", "In this notebook we will implement a small transformer model for a machine translation task. Our model will be able to translate human readable dates in any format to YYYY-MM-DD format.\n", "\n", "We will be using the `faker` module to generate our dataset." ] }, { "cell_type": "code", "execution_count": null, "id": "4f15b8f3", "metadata": { "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", "_kg_hide-output": true, "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", "execution": { "iopub.execute_input": "2022-04-19T17:09:39.165791Z", "iopub.status.busy": "2022-04-19T17:09:39.164311Z", "iopub.status.idle": "2022-04-19T17:09:51.020422Z", "shell.execute_reply": "2022-04-19T17:09:51.019834Z", "shell.execute_reply.started": "2022-04-19T16:26:16.961398Z" }, "id": "4f15b8f3", "papermill": { "duration": 11.87833, "end_time": "2022-04-19T17:09:51.020579", "exception": false, "start_time": "2022-04-19T17:09:39.142249", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "%%capture\n", "!pip install -q faker" ] }, { "cell_type": "code", "execution_count": null, "id": "72747262", "metadata": { "execution": { "iopub.execute_input": "2022-04-19T17:09:51.064603Z", "iopub.status.busy": 
"2022-04-19T17:09:51.063890Z", "iopub.status.idle": "2022-04-19T17:09:51.939467Z", "shell.execute_reply": "2022-04-19T17:09:51.939924Z", "shell.execute_reply.started": "2022-04-19T16:26:27.654284Z" }, "id": "72747262", "papermill": { "duration": 0.899865, "end_time": "2022-04-19T17:09:51.940100", "exception": false, "start_time": "2022-04-19T17:09:51.040235", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from tqdm.auto import tqdm\n", "import re, random\n", "\n", "from faker import Faker\n", "from babel.dates import format_date\n", "\n", "pd.options.display.max_colwidth = None\n", "sns.set_style('darkgrid')" ] }, { "cell_type": "code", "execution_count": null, "id": "aaa08590", "metadata": { "execution": { "iopub.execute_input": "2022-04-19T17:09:51.986753Z", "iopub.status.busy": "2022-04-19T17:09:51.986001Z", "iopub.status.idle": "2022-04-19T17:09:56.373183Z", "shell.execute_reply": "2022-04-19T17:09:56.372645Z", "shell.execute_reply.started": "2022-04-19T16:26:28.610487Z" }, "id": "aaa08590", "papermill": { "duration": 4.41297, "end_time": "2022-04-19T17:09:56.373332", "exception": false, "start_time": "2022-04-19T17:09:51.960362", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras.preprocessing.text import Tokenizer\n", "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "from tensorflow.keras.models import Model, Sequential\n", "from tensorflow.keras import losses, callbacks, utils, models, Input\n", "from tensorflow.keras import layers as L" ] }, { "cell_type": "markdown", "id": "5aa3910c", "metadata": { "id": "5aa3910c", "papermill": { "duration": 0.019209, "end_time": "2022-04-19T17:09:56.411691", "exception": false, "start_time": "2022-04-19T17:09:56.392482", "status": "completed" }, "tags": [] }, "source": 
[ "# Data Generation" ] }, { "cell_type": "code", "execution_count": null, "id": "86eeb03e", "metadata": { "execution": { "iopub.execute_input": "2022-04-19T17:09:56.454573Z", "iopub.status.busy": "2022-04-19T17:09:56.453776Z", "iopub.status.idle": "2022-04-19T17:09:56.456276Z", "shell.execute_reply": "2022-04-19T17:09:56.455821Z", "shell.execute_reply.started": "2022-04-19T16:26:33.512965Z" }, "id": "86eeb03e", "papermill": { "duration": 0.026236, "end_time": "2022-04-19T17:09:56.456385", "exception": false, "start_time": "2022-04-19T17:09:56.430149", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# Constants\n", "class config():\n", "    # Central place for the notebook's tunable constants.\n", "    SAMPLE_SIZE = 1_000_000\n", "    # Babel/CLDR date patterns. Lowercase 'y' is the calendar year; uppercase 'Y'\n", "    # is the ISO week-numbering year, which differs from the calendar year around\n", "    # 1 January and would make the formatted date disagree with its ISO target.\n", "    DATE_FORMATS = [\n", "        'short', 'medium', 'long', 'full',\n", "        'd MMM yyyy', 'd MMMM yyyy', 'dd/MM/yyyy',\n", "        'EE d, MMM yyyy', 'EEEE d, MMMM yyyy', 'd of MMMM yyyy',\n", "    ]\n", "    VALIDATION_SIZE = 0.1\n", "    BATCH_SIZE = 32\n", "    MAX_EPOCHS = 25\n", "    EMBED_DIM = 16\n", "    DENSE_DIM = 16\n", "    NUM_HEADS = 2\n", "    X_LEN = 30\n", "    Y_LEN = 14\n", "    NUM_ENCODER_TOKENS = 35\n", "    NUM_DECODER_TOKENS = 14" ] }, { "cell_type": "code", "execution_count": null, "id": "b0ce1618", "metadata": { "execution": { "iopub.execute_input": "2022-04-19T17:09:56.497409Z", "iopub.status.busy": "2022-04-19T17:09:56.496869Z", "iopub.status.idle": "2022-04-19T17:09:56.550468Z", "shell.execute_reply": "2022-04-19T17:09:56.551032Z", "shell.execute_reply.started": "2022-04-19T16:26:33.526883Z" }, "id": "b0ce1618", "outputId": "5aba41ad-61a3-4126-f7bc-ed7f5c6e3050", "papermill": { "duration": 0.076264, "end_time": "2022-04-19T17:09:56.551205", "exception": false, "start_time": "2022-04-19T17:09:56.474941", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sample dates for each format\n", "\n", "d MMMM YYY => 24 October 2007\n", "short => 11/19/75\n", "dd/MM/YYY => 22/12/1982\n", "long => June 11, 2013\n", "d of MMMM YYY => 16 of June 1971\n", "d MMM YYY => 
23 Aug 1976\n", "EEEE d, MMMM YYY => Saturday 19, October 2013\n", "EE d, MMM YYY => Sun 12, Nov 2017\n", "full => Friday, May 24, 1996\n", "medium => Jul 24, 2012\n" ] } ], "source": [ "faker = Faker()\n", "print('Sample dates for each format\\n')\n", "\n", "# iterate the list itself (not a set) so the formats print in a stable order\n", "for fmt in config.DATE_FORMATS:\n", "    print(f'{fmt:20} => {format_date(faker.date_object(), format=fmt, locale=\"en\")}')" ] }, { "cell_type": "code", "execution_count": null, "id": "60475f85", "metadata": { "execution": { "iopub.execute_input": "2022-04-19T17:09:56.597660Z", "iopub.status.busy": "2022-04-19T17:09:56.596963Z", "iopub.status.idle": "2022-04-19T17:09:56.599245Z", "shell.execute_reply": "2022-04-19T17:09:56.599632Z", "shell.execute_reply.started": "2022-04-19T16:31:12.688221Z" }, "id": "60475f85", "papermill": { "duration": 0.028292, "end_time": "2022-04-19T17:09:56.599753", "exception": false, "start_time": "2022-04-19T17:09:56.571461", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# a utility data cleaning function\n", "def clean_date(raw_date):\n", "    \"\"\"Lower-case a formatted date string and remove its commas.\"\"\"\n", "    return raw_date.lower().replace(',', '')\n", "\n", "# this function will generate our data in a data frame\n", "def create_dataset(num_rows):\n", "    \"\"\"Generate the translation pairs as a DataFrame.\n", "\n", "    For each of `num_rows` random dates from `faker`, every pattern in\n", "    `config.DATE_FORMATS` is rendered with the 'en' locale, cleaned with\n", "    `clean_date`, and paired with the ISO date wrapped in the start token '@'\n", "    and the end token '#'. A pattern that raises AttributeError is skipped.\n", "\n", "    Returns a DataFrame with columns 'human_readable' and 'machine_readable'.\n", "    \"\"\"\n", "    dataset = []\n", "\n", "    for _ in tqdm(range(num_rows)):\n", "        dt = faker.date_object()\n", "        # the target is the same for every format of this date, so build it once\n", "        machine_readable = f\"@{dt.isoformat()}#\"  # adding a start token '@' and end token '#'\n", "        for fmt in config.DATE_FORMATS:\n", "            try:\n", "                human_readable = clean_date(format_date(dt, format=fmt, locale='en'))\n", "            except AttributeError:\n", "                # best-effort: drop this (date, format) pair and keep going\n", "                continue\n", "            dataset.append((human_readable, machine_readable))\n", "\n", "    return pd.DataFrame(dataset, columns=['human_readable', 'machine_readable'])" ] }, { "cell_type": "code", "execution_count": null, "id": "6def5add", "metadata": { "colab": { "referenced_widgets": [ 
"0ac32bd68ce545b59c43e0b6b6e5614d" ] }, "execution": { "iopub.execute_input": "2022-04-19T17:09:56.642243Z", "iopub.status.busy": "2022-04-19T17:09:56.641504Z", "iopub.status.idle": "2022-04-19T17:14:50.497206Z", "shell.execute_reply": "2022-04-19T17:14:50.497826Z", "shell.execute_reply.started": "2022-04-19T16:31:13.00423Z" }, "id": "6def5add", "outputId": "052d7e71-674c-4358-e65c-6e5a43e37b8f", "papermill": { "duration": 293.879505, "end_time": "2022-04-19T17:14:50.498045", "exception": false, "start_time": "2022-04-19T17:09:56.618540", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0ac32bd68ce545b59c43e0b6b6e5614d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1000000 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "(187584, 2)\n" ] }, { "data": { "text/html": [ "
\n", " | human_readable | \n", "machine_readable | \n", "
---|---|---|
0 | \n", "06/10/2001 | \n", "@2001-10-06# | \n", "
1 | \n", "sun 14 aug 2005 | \n", "@2005-08-14# | \n", "
2 | \n", "thu 5 oct 1972 | \n", "@1972-10-05# | \n", "
3 | \n", "tuesday 27 october 1970 | \n", "@1970-10-27# | \n", "
4 | \n", "saturday 5 april 2014 | \n", "@2014-04-05# | \n", "