{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "1326a46f-6f88-47b3-9dfb-28ada9cc39d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/bin/bash: /home/andy/miniconda3/envs/tf/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
      "Collecting pandas\n",
      "  Obtaining dependency information for pandas from https://files.pythonhosted.org/packages/9e/0d/91a9fd2c202f2b1d97a38ab591890f86480ecbb596cbc56d035f6f23fdcc/pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata\n",
      "  Downloading pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/andy/miniconda3/envs/tf/lib/python3.9/site-packages (from pandas) (2.8.2)\n",
      "Collecting pytz>=2020.1 (from pandas)\n",
      "  Downloading pytz-2023.3-py2.py3-none-any.whl (502 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m502.3/502.3 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting tzdata>=2022.1 (from pandas)\n",
      "  Downloading tzdata-2023.3-py2.py3-none-any.whl (341 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m341.8/341.8 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: numpy>=1.20.3 in /home/andy/miniconda3/envs/tf/lib/python3.9/site-packages (from pandas) (1.24.3)\n",
      "Requirement already satisfied: six>=1.5 in /home/andy/miniconda3/envs/tf/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
      "Downloading pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.4/12.4 MB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: pytz, tzdata, pandas\n",
      "Successfully installed pandas-2.0.3 pytz-2023.3 tzdata-2023.3\n"
     ]
    }
   ],
   "source": [
    "# Pin the pandas version this notebook was last run with; %pip installs into the kernel's environment\n",
    "%pip install pandas==2.0.3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c327752-a57d-4bde-bf57-796302eb9468",
   "metadata": {},
   "source": [
    "## Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1de2223f-1631-4323-bedf-d52ee85c302d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n",
    "# Set the number of CPU cores for TensorFlow operations using environment variables\n",
    "os.environ[\"TF_NUM_INTRAOP_THREADS\"] = \"4\"\n",
    "os.environ[\"TF_NUM_INTEROP_THREADS\"] = \"4\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "22cf7c2a-ae01-4192-b32d-1b539f6726b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import difflib\n",
    "import math\n",
    "import random\n",
    "import re\n",
    "\n",
    "# Set mixed precision policy\n",
    "import keras_tuner as kt\n",
    "import matplotlib.pyplot as plt\n",
    "import nltk\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pysnooper\n",
    "\n",
    "# print(tf.config.list_physical_devices('GPU'))\n",
    "import tensorflow as tf\n",
    "from nltk.corpus import words\n",
    "from tensorflow.keras import mixed_precision\n",
    "from tensorflow.keras.callbacks import LearningRateScheduler, TensorBoard\n",
    "from tensorflow.keras.layers import LSTM, Dense, SimpleRNN\n",
    "from tensorflow.keras.models import Sequential, load_model\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.optimizers.schedules import ExponentialDecay\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7576ee8f-2c2d-4ab6-80d0-21229343b035",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package words to\n",
      "[nltk_data]     C:\\Users\\i5\\AppData\\Roaming\\nltk_data...\n",
      "[nltk_data]   Unzipping corpora\\words.zip.\n"
     ]
    }
   ],
   "source": [
    "# Fetch the NLTK \"words\" corpus into the local nltk_data directory\n",
    "nltk.download(\"words\")\n",
    "\n",
    "# Get the list of English words; refine_words() uses it as its lexicon\n",
    "english_words = words.words(\"en\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fec9afe8-6d3f-4329-ae7b-47d0fcd417b5",
   "metadata": {},
   "source": [
    "## Set up TensorFlow and check the GPU is recognised"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "07d8f2bd-8202-48af-b76b-a079b23f1ebe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n",
      "INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK\n",
      "Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA GeForce RTX 2070, compute capability 7.5\n",
      "Number of CPU cores configured for TensorFlow: 4\n"
     ]
    }
   ],
   "source": [
    "# Reset any global state Keras created in earlier runs of this kernel\n",
    "tf.keras.backend.clear_session()\n",
    "\n",
    "# Enable on-demand GPU memory growth instead of reserving the whole card up front\n",
    "gpus = tf.config.list_physical_devices(\"GPU\")\n",
    "print(gpus)\n",
    "for gpu in gpus:\n",
    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
    "\n",
    "# Mixed precision: float16 compute with float32 variables for layers built from here on\n",
    "policy = mixed_precision.Policy(\"mixed_float16\")\n",
    "mixed_precision.set_global_policy(policy)\n",
    "\n",
    "# TensorBoard callback shared by run_model(); logs are written under ./logs/\n",
    "log_dir = \"logs/\"\n",
    "tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
    "\n",
    "# Echo the intra-op thread setting taken from the environment (TensorFlow itself is not queried)\n",
    "intra_op_threads = os.environ[\"TF_NUM_INTRAOP_THREADS\"]\n",
    "print(f\"Number of CPU cores configured for TensorFlow: {intra_op_threads}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcc670f9-7450-4c70-acd3-797bce963f30",
   "metadata": {},
   "source": [
    "# Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e1b228ed-c7f8-40db-86ae-96fc4a64e8dc",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def test_word_in_corpus(word, corpus):\n",
    "    return word in corpus\n",
    "\n",
    "\n",
    "def replace_with_similar_word(word, corpus):\n",
    "    words = difflib.get_close_matches(word, corpus)\n",
    "    if len(words) == 0:\n",
    "        word = word\n",
    "    else:\n",
    "        word = np.random.choice(words)\n",
    "    return word\n",
    "\n",
    "\n",
    "def refine_words(words):\n",
    "    \"\"\"takes made-up words and creates more pronouncable ones based on parts of the words being in a lexicon\"\"\"\n",
    "    new_words = []\n",
    "    for w in tqdm(words):\n",
    "        for i in range(2, len(w)):\n",
    "            word_length = len(w)\n",
    "            if i == 0:\n",
    "                suffix_word = w\n",
    "                prefix_word = \"\"\n",
    "            else:\n",
    "                suffix_word = w[:-i]\n",
    "                prefix_word = w[word_length - i :]\n",
    "\n",
    "            # if prefix word in corpus then fix. get similar suffix\n",
    "            # if prefix word is not in corpus then fix similar at 3 chars and do same for suffix\n",
    "            if test_word_in_corpus(suffix_word, english_words):\n",
    "                suf = suffix_word\n",
    "                pref = replace_with_similar_word(prefix_word, english_words)\n",
    "                break\n",
    "            if test_word_in_corpus(prefix_word, english_words) & len(prefix_word) > 3:\n",
    "                pref = prefix_word\n",
    "                suf = replace_with_similar_word(suffix_word, english_words)\n",
    "                break\n",
    "            if len(suffix_word) == 3:\n",
    "                suf = replace_with_similar_word(suffix_word, english_words)\n",
    "                pref = replace_with_similar_word(prefix_word, english_words)\n",
    "                break\n",
    "\n",
    "        if (len(suf) >= 5) | (len(pref) >= 5):\n",
    "            suf = suf + \" \"\n",
    "        new_words.append(suf + pref)\n",
    "    return new_words\n",
    "\n",
    "\n",
    "def remove_invalid_words(words):\n",
    "    keep_words = []\n",
    "    invalid_chars = set(\n",
    "        [\n",
    "            \"\\xa0\",\n",
    "            \"¡\",\n",
    "            \"¢\",\n",
    "            \"¨\",\n",
    "            \"¬\",\n",
    "            \"®\",\n",
    "            \"²\",\n",
    "            \"´\",\n",
    "            \"à\",\n",
    "            \"â\",\n",
    "            \"ã\",\n",
    "            \"è\",\n",
    "            \"ê\",\n",
    "            \"ì\",\n",
    "            \"î\",\n",
    "            \"ò\",\n",
    "            \"ô\",\n",
    "            \"ù\",\n",
    "            \"û\",\n",
    "            \"ˆ\",\n",
    "            \"™\",\n",
    "            \"!\",\n",
    "            \"&\",\n",
    "            \"'\",\n",
    "            \"(\",\n",
    "            \")\",\n",
    "            \",\",\n",
    "            \"-\",\n",
    "            \".\",\n",
    "            \"/\",\n",
    "            \":\",\n",
    "        ]\n",
    "    )\n",
    "    for w in tqdm(words):\n",
    "        if len(set(w).intersection(invalid_chars)) == 0:\n",
    "            keep_words.append(w)\n",
    "    return keep_words\n",
    "\n",
    "\n",
    "def generate_word_list_df(df, col):\n",
    "    \"\"\"creates a list of lowercase words with the new-line \\n character appended\"\"\"\n",
    "    words = list(df[col])\n",
    "    words = [word.lower() + \"\\n\" for word in words]\n",
    "    return words\n",
    "\n",
    "\n",
    "def generate_word_list(filename):\n",
    "    data = open(filename, \"r\").read()\n",
    "    data = data.lower()\n",
    "    words = data.split(\"\\n\")\n",
    "    words = [word + \"\\n\" for word in words]\n",
    "    return words\n",
    "\n",
    "\n",
    "def get_vocab(word_list):\n",
    "    vocab = sorted(set(\" \".join(word_list)))\n",
    "    return vocab\n",
    "\n",
    "\n",
    "def get_vocab_size_and_dicts(vocab):\n",
    "    vocab_size = len(vocab)\n",
    "    char_to_idx = {char: idx for idx, char in enumerate(vocab)}\n",
    "    idx_to_char = {idx: char for idx, char in enumerate(vocab)}\n",
    "    return vocab_size, char_to_idx, idx_to_char\n",
    "\n",
    "\n",
    "def generate_training_data(words):\n",
    "    \"\"\"Build one-hot next-character training data from newline-terminated words.\n",
    "\n",
    "    Every prefix ``word[: i + 1]`` (i = 0 .. len(word) - 2) becomes one input\n",
    "    sequence whose target is the following character ``word[i + 1]``, so the\n",
    "    final newline is learned as the end-of-word prediction.\n",
    "\n",
    "    Returns:\n",
    "        sequences_encoded: one-hot inputs, shape (n_sequences, max_len, vocab_size),\n",
    "            post-padded; channel 0 is cleared so padded steps are all zeros.\n",
    "        y_encoded: one-hot targets, shape (n_sequences, vocab_size).\n",
    "        vocab: sorted list of the characters in `words`.\n",
    "        vocab_size: len(vocab).\n",
    "        char_to_idx: character -> index, to encode a word.\n",
    "        idx_to_char: index -> character, to decode a prediction.\n",
    "        max_len: length of the longest word, newline included.\n",
    "    \"\"\"\n",
    "    vocab = get_vocab(words)\n",
    "    vocab_size, char_to_idx, idx_to_char = get_vocab_size_and_dicts(vocab)\n",
    "    max_len = max(map(len, words))\n",
    "    sequences = []\n",
    "    y_sequence = []\n",
    "    for word in tqdm(words):\n",
    "        for i in range(len(word) - 1):\n",
    "            sequences.append([char_to_idx[char] for char in word[: i + 1]])\n",
    "            y_sequence.append(char_to_idx[word[i + 1]])\n",
    "\n",
    "    sequences_padded = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "        sequences, maxlen=max_len, padding=\"post\"\n",
    "    )\n",
    "    sequences_encoded = tf.keras.utils.to_categorical(\n",
    "        sequences_padded, num_classes=vocab_size\n",
    "    )\n",
    "    # The padding value 0 is also the index of the newline character, so every\n",
    "    # padded step came out one-hot as a newline. Clear channel 0 for all steps\n",
    "    # in one vectorised assignment (same result as the old per-step Python loop,\n",
    "    # which took seconds on ~570k sequences). Inputs are prefixes that exclude\n",
    "    # the word's last character, so assuming the newline only appears at the\n",
    "    # end of each word, no real character is erased.\n",
    "    sequences_encoded[..., 0] = 0\n",
    "    y_encoded = tf.keras.utils.to_categorical(y_sequence, num_classes=vocab_size)\n",
    "    return (\n",
    "        sequences_encoded,\n",
    "        y_encoded,\n",
    "        vocab,\n",
    "        vocab_size,\n",
    "        char_to_idx,\n",
    "        idx_to_char,\n",
    "        max_len,\n",
    "    )\n",
    "\n",
    "\n",
    "# Define the learning rate decay function\n",
    "def learning_rate_decay(initial_learning_rate, decay_steps, decay_rate):\n",
    "    global_step = tf.Variable(0, trainable=False)\n",
    "    learning_rate = ExponentialDecay(\n",
    "        initial_learning_rate, decay_steps, decay_rate, staircase=True\n",
    "    )\n",
    "    return learning_rate\n",
    "\n",
    "\n",
    "def get_learning_rate(x, batch_size, initial_learning_rate=0.001, decay_rate=0.96):\n",
    "    batches_per_epoch = len(x) / batch_size\n",
    "    print(\"batche_per_epoch \", batches_per_epoch)\n",
    "    change_lr_epochs = 50\n",
    "    decay_steps = batches_per_epoch * change_lr_epochs\n",
    "\n",
    "    learning_rate = learning_rate_decay(initial_learning_rate, decay_steps, decay_rate)\n",
    "    return learning_rate\n",
    "\n",
    "\n",
    "def build_model_rnn(size, x, batch_size, max_len, vocab_size):\n",
    "    \"\"\"Build and compile a one-layer SimpleRNN next-character model.\n",
    "\n",
    "    Args:\n",
    "        size: number of SimpleRNN units.\n",
    "        x, batch_size: only used to size the decaying learning-rate schedule.\n",
    "        max_len, vocab_size: input shape (time steps, one-hot width); the\n",
    "            softmax output layer also has `vocab_size` units.\n",
    "    \"\"\"\n",
    "    # Unused alternative kept from earlier: ReduceLROnPlateau(monitor=\"loss\",\n",
    "    # factor=0.96, patience=5, min_lr=1e-7) instead of the decay schedule.\n",
    "    optimizer = Adam(clipvalue=5, learning_rate=get_learning_rate(x, batch_size))\n",
    "    layers = [\n",
    "        SimpleRNN(size, input_shape=(max_len, vocab_size), dropout=0.0),\n",
    "        Dense(vocab_size, activation=\"softmax\"),\n",
    "    ]\n",
    "    model_rnn = Sequential(layers)\n",
    "    model_rnn.compile(loss=\"categorical_crossentropy\", optimizer=optimizer)\n",
    "    return model_rnn\n",
    "\n",
    "\n",
    "def build_model_LSTM(size, x, batch_size, max_len, vocab_size, learning_rate=None):\n",
    "    \"\"\"Build and compile a one-layer LSTM next-character model.\n",
    "\n",
    "    Args:\n",
    "        size: number of LSTM units.\n",
    "        x, batch_size: only used when `learning_rate` is None, to size the\n",
    "            default decaying learning-rate schedule.\n",
    "        max_len, vocab_size: input shape (time steps, one-hot width).\n",
    "        learning_rate: optional rate or schedule. When given, Adam clips each\n",
    "            gradient value to [-5, 5]; otherwise the default schedule is used\n",
    "            with per-weight gradient norm clipping at 1.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the two branches clip differently (the default branch was\n",
    "    # changed from clipvalue=5 to clipnorm=1); confirm that is intended.\n",
    "    if learning_rate is None:\n",
    "        optimizer = Adam(clipnorm=1, learning_rate=get_learning_rate(x, batch_size))\n",
    "    else:\n",
    "        optimizer = Adam(clipvalue=5, learning_rate=learning_rate)\n",
    "    layers = [\n",
    "        LSTM(size, input_shape=(max_len, vocab_size)),\n",
    "        Dense(vocab_size, activation=\"softmax\"),\n",
    "    ]\n",
    "    model = Sequential(layers)\n",
    "    model.compile(loss=\"categorical_crossentropy\", optimizer=optimizer)\n",
    "    return model\n",
    "\n",
    "\n",
    "def build_model_LSTM_tune(hp, max_len, voacb_size):\n",
    "    \"\"\"this creates a hyperparamter tuneable model, where we set the range of values to explore in\n",
    "    size of model (units); learning_rate; dropout, clipvalue.\"\"\"\n",
    "\n",
    "    hp_units = hp.Int(\"units\", min_value=24, max_value=124, step=10)\n",
    "    hp_learning_rate = hp.Float(\"learning_rate\", 1e-7, 1e-2, sampling=\"log\")\n",
    "    hp_dropout = hp.Float(\"dropout\", min_value=0, max_value=0.5, step=0.1)\n",
    "    hp_clipvalue = hp.Int(\"clipvalue\", min_value=0, max_value=5, step=1)\n",
    "    optimizer = Adam(clipvalue=hp_clipvalue, learning_rate=hp_learning_rate)\n",
    "    model = Sequential()\n",
    "    model.add(\n",
    "        LSTM(units=hp_units, input_shape=(max_len, vocab_size), dropout=hp_dropout)\n",
    "    )\n",
    "    model.add(Dense(vocab_size, activation=\"softmax\"))\n",
    "    model.compile(loss=\"categorical_crossentropy\", optimizer=optimizer)\n",
    "    return model\n",
    "\n",
    "\n",
    "def generate_words(\n",
    "    model,\n",
    "    vocab_size,\n",
    "    max_len,\n",
    "    idx_to_char,\n",
    "    char_to_idx,\n",
    "    number=20,\n",
    "    temperature=1,\n",
    "    seed_word=None,\n",
    "):\n",
    "    \"\"\"takes the model and generates words based on softmax output for each character, it will run through the model for\n",
    "    every character in the sequence and randomly sample from the character probabilities (not the max probability) this means\n",
    "    we get variable words each time\"\"\"\n",
    "    seed_word_original = seed_word\n",
    "\n",
    "    def generate_word(seed_word, i=0):\n",
    "        def adjust_temperature(predictions, temperature):\n",
    "            predictions = np.log(predictions) / temperature\n",
    "            exp_preds = np.exp(predictions)\n",
    "            adjusted_preds = exp_preds / np.sum(exp_preds)\n",
    "            return adjusted_preds\n",
    "\n",
    "        def next_char(preds):\n",
    "            next_idx = np.random.choice(range(vocab_size), p=preds.ravel())\n",
    "            # next_idx = np.argmax(preds)\n",
    "            char = idx_to_char[next_idx]\n",
    "            return char\n",
    "\n",
    "        def word_to_input(word: str):\n",
    "            \"\"\"takes a string and turns it into a sequence matrix\"\"\"\n",
    "            x_pred = np.zeros((1, max_len, vocab_size))\n",
    "            for t, char in enumerate(word):\n",
    "                x_pred[0, t, char_to_idx[char]] = 1.0\n",
    "            return x_pred\n",
    "\n",
    "        if len(seed_word) == max_len:\n",
    "            return seed_word\n",
    "\n",
    "        x_input = word_to_input(seed_word)\n",
    "        preds = model.predict(x_input, verbose=False)\n",
    "        if temperature != 1:\n",
    "            preds = adjust_temperature(preds, temperature)\n",
    "        char = next_char(preds)\n",
    "        i += 1\n",
    "        # print(seed_word, char, i)\n",
    "\n",
    "        if char == \"\\n\":\n",
    "            return seed_word\n",
    "        else:\n",
    "            return generate_word(seed_word + char, i)\n",
    "\n",
    "    output = []\n",
    "    print(\"generating words\")\n",
    "    for i in range(number):\n",
    "        if seed_word is None:\n",
    "            seed_word = idx_to_char[np.random.choice(np.arange(2, len(char_to_idx)))]\n",
    "        word = generate_word(seed_word)\n",
    "        output.append(word)\n",
    "        seed_word = seed_word_original\n",
    "    return output\n",
    "\n",
    "\n",
    "def run_model(model, x, y, epochs, size, type=\"rnn\", verbose=1, batch_size=1024):\n",
    "    model.fit(\n",
    "        x,\n",
    "        y,\n",
    "        batch_size=batch_size,\n",
    "        epochs=epochs,\n",
    "        verbose=verbose,\n",
    "        callbacks=[tensorboard_callback],\n",
    "    )\n",
    "    model.save(f\"{type}_{size}.keras\")\n",
    "    return model.history.history[\"loss\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5832d46-281d-4c7e-a4b0-e00bfd2fd614",
   "metadata": {},
   "source": [
    "# Get and create training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9f28d4d3-c28a-4ee6-9a60-32ef72f0b331",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the place-name list and add each name's character length,\n",
    "# which is used below to drop very long names\n",
    "df = pd.read_excel(\"placenames.xlsx\").assign(\n",
    "    length=lambda frame: frame[\"place_name\"].map(len)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ef1c0151-ed29-4e98-9b56-c4715bc805da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>place_name</th>\n",
       "      <th>country</th>\n",
       "      <th>length</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>54017</th>\n",
       "      <td>Lands common to Hamsterley, Lynesack and Softl...</td>\n",
       "      <td>England</td>\n",
       "      <td>66</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>106176</th>\n",
       "      <td>Wolsingham Park Moor lands cmn to Stanhope, To...</td>\n",
       "      <td>England</td>\n",
       "      <td>66</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>54018</th>\n",
       "      <td>Lands common to Holme Abbey, Holme Low and Hol...</td>\n",
       "      <td>England</td>\n",
       "      <td>60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>54016</th>\n",
       "      <td>Lands common to Fylingdales and Hawsker-cum-St...</td>\n",
       "      <td>England</td>\n",
       "      <td>54</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16228</th>\n",
       "      <td>Cadeby, Carlton and Market Bosworth with Shack...</td>\n",
       "      <td>England</td>\n",
       "      <td>53</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               place_name  country  length\n",
       "54017   Lands common to Hamsterley, Lynesack and Softl...  England      66\n",
       "106176  Wolsingham Park Moor lands cmn to Stanhope, To...  England      66\n",
       "54018   Lands common to Holme Abbey, Holme Low and Hol...  England      60\n",
       "54016   Lands common to Fylingdales and Hawsker-cum-St...  England      54\n",
       "16228   Cadeby, Carlton and Market Bosworth with Shack...  England      53"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Longest names first: multi-place descriptions like these motivate a length cap\n",
    "df.sort_values(by=\"length\", ascending=False).head(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adaae5a9-8554-4545-bdc9-b00d39f290b5",
   "metadata": {},
   "source": [
    "We will only use English place names shorter than 25 characters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "092af025-943e-4db4-977f-2b62d1a1d97c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unique rows for England only, with names shorter than MAX_NAME_LENGTH characters\n",
    "MAX_NAME_LENGTH = 25\n",
    "df = df.drop_duplicates().query(\"country == 'England' and length < @MAX_NAME_LENGTH\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "92778159-314c-43b8-b6fa-73bc52019493",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>place_name</th>\n",
       "      <th>country</th>\n",
       "      <th>length</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>47950</th>\n",
       "      <td>Hoo</td>\n",
       "      <td>England</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>516</th>\n",
       "      <td>Aby</td>\n",
       "      <td>England</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>70461</th>\n",
       "      <td>Nox</td>\n",
       "      <td>England</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>70847</th>\n",
       "      <td>Oby</td>\n",
       "      <td>England</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16719</th>\n",
       "      <td>Cam</td>\n",
       "      <td>England</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>105995</th>\n",
       "      <td>Witney North and East ED</td>\n",
       "      <td>England</td>\n",
       "      <td>24</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40059</th>\n",
       "      <td>Great and Little Hampden</td>\n",
       "      <td>England</td>\n",
       "      <td>24</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>51203</th>\n",
       "      <td>Kensington and Fairfield</td>\n",
       "      <td>England</td>\n",
       "      <td>24</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>61007</th>\n",
       "      <td>Lymington and Pennington</td>\n",
       "      <td>England</td>\n",
       "      <td>24</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>63345</th>\n",
       "      <td>Meonstoke and Corhampton</td>\n",
       "      <td>England</td>\n",
       "      <td>24</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>49133 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                      place_name  country  length\n",
       "47950                        Hoo  England       3\n",
       "516                          Aby  England       3\n",
       "70461                        Nox  England       3\n",
       "70847                        Oby  England       3\n",
       "16719                        Cam  England       3\n",
       "...                          ...      ...     ...\n",
       "105995  Witney North and East ED  England      24\n",
       "40059   Great and Little Hampden  England      24\n",
       "51203   Kensington and Fairfield  England      24\n",
       "61007   Lymington and Pennington  England      24\n",
       "63345   Meonstoke and Corhampton  England      24\n",
       "\n",
       "[49133 rows x 3 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shortest names first; pandas displays only the head and tail of the sorted frame\n",
    "df.sort_values(by=\"length\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5cd790a5-967c-431e-9a69-19a54c8d331c",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████| 49133/49133 [00:00<00:00, 52373.17it/s]\n",
      "100%|████████████████████████████████████████████████████████████████| 570251/570251 [00:02<00:00, 236218.66it/s]\n"
     ]
    }
   ],
   "source": [
    "places = generate_word_list_df(df, \"place_name\")\n",
    "# The names arrive in alphabetical order; shuffle them with a fixed seed so the\n",
    "# order is the same on every Restart & Run All (global `random` state is untouched)\n",
    "SHUFFLE_SEED = 42\n",
    "random.Random(SHUFFLE_SEED).shuffle(places)\n",
    "x, y, vocab, vocab_size, char_to_idx, idx_to_char, max_len = generate_training_data(\n",
    "    places\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e68645d5-16dd-4a63-8917-0741ce2048c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selection of characters:  ['\\n', ' ', '!', '&', \"'\", '(', ')', ',', '-', '.', '/', ':', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']\n"
     ]
    }
   ],
   "source": [
    "# vocab is sorted by code point; show its first 20 entries\n",
    "print(\"Selection of characters: \", vocab[:20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ea0dd79c-6f6e-49d5-9820-4720098381c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are  38  different characters to encode including end of line \\n\n"
     ]
    }
   ],
   "source": [
    "# vocab_size counts every distinct character, including the newline end-of-word marker\n",
    "print(\"There are \", vocab_size, \" different characters to encode including end of line \\\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cbaa4022-6b15-4bf8-8bdd-9e53d6c2c054",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "we encode each character:  {'\\n': 0, ' ': 1, '!': 2, '&': 3, \"'\": 4, '(': 5, ')': 6, ',': 7, '-': 8, '.': 9, '/': 10, ':': 11, 'a': 12, 'b': 13, 'c': 14, 'd': 15, 'e': 16, 'f': 17, 'g': 18, 'h': 19, 'i': 20, 'j': 21, 'k': 22, 'l': 23, 'm': 24, 'n': 25, 'o': 26, 'p': 27, 'q': 28, 'r': 29, 's': 30, 't': 31, 'u': 32, 'v': 33, 'w': 34, 'x': 35, 'y': 36, 'z': 37} \n",
      " and decode using:  {0: '\\n', 1: ' ', 2: '!', 3: '&', 4: \"'\", 5: '(', 6: ')', 7: ',', 8: '-', 9: '.', 10: '/', 11: ':', 12: 'a', 13: 'b', 14: 'c', 15: 'd', 16: 'e', 17: 'f', 18: 'g', 19: 'h', 20: 'i', 21: 'j', 22: 'k', 23: 'l', 24: 'm', 25: 'n', 26: 'o', 27: 'p', 28: 'q', 29: 'r', 30: 's', 31: 't', 32: 'u', 33: 'v', 34: 'w', 35: 'x', 36: 'y', 37: 'z'}\n"
     ]
    }
   ],
   "source": [
    "# char_to_idx encodes characters as indices; idx_to_char decodes them back\n",
    "print(\n",
    "    \"we encode each character: \",\n",
    "    char_to_idx,\n",
    "    \"\\n and decode using: \",\n",
    "    idx_to_char,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f26bb89f-0438-498b-875d-415c986a76c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the longest word we will create is based on the longest word in the training data which is:  25\n"
     ]
    }
   ],
   "source": [
    "# Generated words are capped at max_len characters (longest training word, newline included)\n",
    "print(\"the longest word we will create is based on the longest word in the training data which is: \", max_len)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "198ba250-ed94-4d7e-a6fa-6131276fd055",
   "metadata": {},
   "source": [
    "## Convert data to tensors\n",
    "TensorFlow handles batching and similar operations better when the data is already in its own tensor format, so the NumPy arrays are converted once here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ada8ff39-3be6-4683-a23e-e3e2c9ec6a8a",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-09-09 18:49:03.754599: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node\n",
      "Your kernel may have been built without NUMA support.\n",
      "2023-09-09 18:49:03.754820: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node\n",
      "Your kernel may have been built without NUMA support.\n",
      "2023-09-09 18:49:03.754899: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node\n",
      "Your kernel may have been built without NUMA support.\n",
      "2023-09-09 18:49:05.481597: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node\n",
      "Your kernel may have been built without NUMA support.\n",
      "2023-09-09 18:49:05.481686: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node\n",
      "Your kernel may have been built without NUMA support.\n",
      "2023-09-09 18:49:05.481696: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1722] Could not identify NUMA node of platform GPU id 0, defaulting to 0.  Your kernel may not have been built with NUMA support.\n",
      "2023-09-09 18:49:05.481736: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:982] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node\n",
      "Your kernel may have been built without NUMA support.\n",
      "2023-09-09 18:49:05.481764: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5897 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2070, pci bus id: 0000:01:00.0, compute capability: 7.5\n",
      "2023-09-09 18:49:05.487757: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 2166953800 exceeds 10% of free system memory.\n"
     ]
    }
   ],
   "source": [
    "# Convert the training inputs (x) and targets (y) into TensorFlow tensors\n",
    "x_tensor, y_tensor = (tf.convert_to_tensor(arr) for arr in (x, y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5f0e526e-ce51-44f7-b25f-1c4f20866d10",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "25"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the time axis of the input tensor must equal max_len,\n",
    "# because the LSTM is built with input_shape=(max_len, vocab_size)\n",
    "assert x_tensor.shape[1] == max_len, \"x_tensor time axis does not match max_len\"\n",
    "max_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "1569b07f-1940-4149-b182-8667c71dabc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-09-09 18:49:09.447314: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 2166953800 exceeds 10% of free system memory.\n"
     ]
    }
   ],
   "source": [
    "# Batched view of the data (batch_size 32); it is rebuilt with batch_size 512 before training\n",
    "batch_size = 32\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices(\n",
    "    (x_tensor, y_tensor)\n",
    ").batch(batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "114e7211-d29a-4568-bb11-6c7a298533c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<_BatchDataset element_spec=(TensorSpec(shape=(None, 25, 38), dtype=tf.float32, name=None), TensorSpec(shape=(None, 38), dtype=tf.float32, name=None))>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Each element is an (inputs, targets) pair; check the per-example input shape\n",
    "assert train_dataset.element_spec[0].shape.as_list()[1:] == [max_len, vocab_size], \"unexpected input shape\"\n",
    "train_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01ddd4aa-2c76-4086-bc4b-2e2eac528ca4",
   "metadata": {},
   "source": [
    "## Create a hypermodel to explore model hyperparameters\n",
    "Keras Tuner builds and compares candidate models for us, so the hyperparameters do not need to be tuned by hand."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5227f705-473f-4ba6-98be-28c5316d8ecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyHyperModel(kt.HyperModel):\n",
    "    \"\"\"Keras Tuner hypermodel for a single-layer LSTM next-character model.\n",
    "\n",
    "    Every built model maps a (max_len, vocab_size) input window through one\n",
    "    LSTM layer to a softmax over vocab_size characters, compiled with Adam and\n",
    "    categorical cross-entropy.\n",
    "\n",
    "    Tuned hyperparameters: LSTM units, learning rate, LSTM input dropout and\n",
    "    gradient clip value.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_len, vocab_size, **kwargs):\n",
    "        \"\"\"Store the data-dependent dimensions.\n",
    "\n",
    "        Args:\n",
    "            max_len: number of time steps in each input window.\n",
    "            vocab_size: number of characters (input features and output classes).\n",
    "            **kwargs: forwarded to kt.HyperModel (e.g. name, tunable).\n",
    "        \"\"\"\n",
    "        # Initialise the kt.HyperModel base class so its own attributes\n",
    "        # (such as `name` and `tunable`) are set up.\n",
    "        super().__init__(**kwargs)\n",
    "        self.max_len = max_len\n",
    "        self.vocab_size = vocab_size\n",
    "\n",
    "    def build(self, hp):\n",
    "        \"\"\"Build and compile one candidate model from the values in `hp`.\"\"\"\n",
    "        hp_units = hp.Int(\"units\", min_value=24, max_value=124, step=10)\n",
    "        hp_learning_rate = hp.Float(\"learning_rate\", 1e-7, 1e-2, sampling=\"log\")\n",
    "        hp_dropout = hp.Float(\"dropout\", min_value=0, max_value=0.5, step=0.1)\n",
    "        hp_clipvalue = hp.Int(\"clipvalue\", min_value=0, max_value=5, step=1)\n",
    "        # clipvalue=0 is intended as \"no clipping\": pass None so that no Keras\n",
    "        # optimizer version clips every gradient to [0, 0] (which stops learning).\n",
    "        optimizer = Adam(\n",
    "            clipvalue=hp_clipvalue or None, learning_rate=hp_learning_rate\n",
    "        )\n",
    "        model = Sequential()\n",
    "        model.add(\n",
    "            LSTM(\n",
    "                units=hp_units,\n",
    "                input_shape=(self.max_len, self.vocab_size),\n",
    "                dropout=hp_dropout,\n",
    "            )\n",
    "        )\n",
    "        model.add(Dense(self.vocab_size, activation=\"softmax\"))\n",
    "        model.compile(loss=\"categorical_crossentropy\", optimizer=optimizer)\n",
    "        return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "036dc075-e253-4d39-a313-b75fb935d5da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind the data-dependent dimensions into the hypermodel\n",
    "hyper_model = MyHyperModel(max_len=max_len, vocab_size=vocab_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "395e8f35-befb-432f-aec5-907732d9b30d",
   "metadata": {},
   "source": [
    "### Run the hyperparameter search\n",
    "Hyperband trains candidate models and keeps the hyperparameters with the lowest training loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "59aa03fc-e7af-4d5c-9106-dab3b1df0619",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reloading Tuner from my_tuner_dir/tuning_english_placenames/tuner0.json\n"
     ]
    }
   ],
   "source": [
    "# Hyperband tuner; trial state is stored under my_tuner_dir/tuning_english_placenames\n",
    "# and reloaded when this cell is re-run.\n",
    "# NOTE(review): the objective is the training loss (no validation data is used),\n",
    "# so the search favours models that fit the training set best.\n",
    "tuner = kt.Hyperband(\n",
    "    hypermodel=hyper_model,\n",
    "    objective=\"loss\",\n",
    "    directory=\"my_tuner_dir\",\n",
    "    project_name=\"tuning_english_placenames\",\n",
    "    max_epochs=100,\n",
    "    factor=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e335045c-d67f-4ed5-b6f4-3d2d2d6d27bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperband assigns each trial its own epoch budget (up to max_epochs) and\n",
    "# overrides any `epochs=` passed here, so none is given.\n",
    "tuner.search(x_tensor, y_tensor, batch_size=512)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10d04dd2-833e-4289-b482-9021fae5bdb3",
   "metadata": {},
   "source": [
    "### Build the model\n",
    "using the best hyperparameters found by the search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "40d82dd4-20d4-4a04-ac33-6f27c25b31b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-09-09 18:50:53.114618: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_2_grad/concat/split_2/split_dim' with dtype int32\n",
      "\t [[{{node gradients/split_2_grad/concat/split_2/split_dim}}]]\n",
      "2023-09-09 18:50:53.117850: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_grad/concat/split/split_dim' with dtype int32\n",
      "\t [[{{node gradients/split_grad/concat/split/split_dim}}]]\n",
      "2023-09-09 18:50:53.119767: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'gradients/split_1_grad/concat/split_1/split_dim' with dtype int32\n",
      "\t [[{{node gradients/split_1_grad/concat/split_1/split_dim}}]]\n"
     ]
    }
   ],
   "source": [
    "# The tuner ranks trials by its objective; take the single best set of values\n",
    "best_hps = tuner.get_best_hyperparameters()[0]\n",
    "\n",
    "# Instantiate a fresh, untrained model from those hyperparameters\n",
    "best_model = tuner.hypermodel.build(best_hps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "70ce6235-a230-43e3-ab38-5b1d542a8ade",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "The hyperparameter search is complete. The optimal number of units in the first densely-connected\n",
      "layer is 114 and the optimal learning rate for the optimizer\n",
      "is 0.003415073321501355, dropout is 0.0, clipvalue is 0.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Summarise the winning hyperparameters (the tuned units belong to the LSTM layer)\n",
    "print(\n",
    "    f\"\"\"\n",
    "The hyperparameter search is complete. The optimal number of units in the LSTM\n",
    "layer is {best_hps.get('units')} and the optimal learning rate for the optimizer\n",
    "is {best_hps.get('learning_rate')}, dropout is {best_hps.get('dropout')}, clipvalue is {best_hps.get('clipvalue')}.\n",
    "\"\"\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f61dbd3e-6041-468e-b4a7-e2c74769df94",
   "metadata": {},
   "source": [
    "### Add learning rate decay\n",
    "Gradually lowering the learning rate helps training converge; then recompile the model with the optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "1093dfb6-c284-4265-af65-ee59f76ca70f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def learning_rate_decay(epoch, initial_lr, decay_rate=0.95, decay_step=10):\n",
    "    \"\"\"Step-wise exponential learning-rate decay.\n",
    "\n",
    "    The rate is multiplied by `decay_rate` once every `decay_step` epochs.\n",
    "\n",
    "    Args:\n",
    "        epoch: 0-based epoch index (as passed by LearningRateScheduler).\n",
    "        initial_lr: learning rate used for the first `decay_step` epochs.\n",
    "        decay_rate: multiplicative factor applied at each step (default 0.95).\n",
    "        decay_step: number of epochs between decays (default 10).\n",
    "\n",
    "    Returns:\n",
    "        initial_lr * decay_rate ** (epoch // decay_step)\n",
    "    \"\"\"\n",
    "    return initial_lr * decay_rate ** (epoch // decay_step)\n",
    "\n",
    "\n",
    "# Reuse the tuned hyperparameters for the final training run\n",
    "initial_lr = best_hps.get(\"learning_rate\")\n",
    "clipvalue = best_hps.get(\"clipvalue\")\n",
    "dropout = best_hps.get(\"dropout\")\n",
    "\n",
    "# Keep the schedule a one-argument callable: LearningRateScheduler passes\n",
    "# (epoch, current_lr) to two-argument schedules, which would compound the decay.\n",
    "lr_decay_callback = LearningRateScheduler(\n",
    "    lambda epoch: learning_rate_decay(epoch, initial_lr)\n",
    ")\n",
    "\n",
    "# Recompile with a fresh Adam optimizer. clipvalue=0 is meant as \"no clipping\",\n",
    "# so pass None instead (some Keras versions would otherwise clip every gradient to 0).\n",
    "optimizer = Adam(clipvalue=clipvalue or None, learning_rate=initial_lr)\n",
    "best_model.compile(loss=\"categorical_crossentropy\", optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "eeda8304-89a0-4a47-8b8b-4242b0a0aa75",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " lstm (LSTM)                 (None, 114)               69768     \n",
      "                                                                 \n",
      " dense (Dense)               (None, 38)                4370      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 74,138\n",
      "Trainable params: 74,138\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Layer-by-layer overview of the rebuilt best model\n",
    "best_model.summary(line_length=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34cbbeb5-0cd0-48e3-8596-515afc82ed25",
   "metadata": {},
   "source": [
    "### Train the model\n",
    "It is best to train from a `tf.data.Dataset`, which handles batching for us."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6565982-6955-4932-a74f-c7b80d8bcc06",
   "metadata": {},
   "source": [
    "### Add checkpoints\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "df6ca631-a954-4b4d-9b7c-ba36adae0fe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.callbacks import ModelCheckpoint\n",
    "\n",
    "# Save a checkpoint only when the training loss improves. There is no validation\n",
    "# data, so monitor \"loss\" (the default, \"val_loss\", is never logged here).\n",
    "# Without save_best_only every epoch overwrites the file, so a late loss spike\n",
    "# would replace the best weights seen so far.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    \"model_checkpoint.h5\", monitor=\"loss\", save_best_only=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "197d08e5-e9ac-46bd-b50a-127ebbf9aa2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 512\n",
    "# Model.fit does not shuffle tf.data datasets itself (unlike the tensors passed\n",
    "# to tuner.search), so shuffle here and reshuffle on every epoch.\n",
    "# The bounded buffer keeps memory modest; raise it for a more thorough shuffle.\n",
    "shuffle_buffer = 10_000\n",
    "train_dataset = (\n",
    "    tf.data.Dataset.from_tensor_slices((x_tensor, y_tensor))\n",
    "    .shuffle(shuffle_buffer, reshuffle_each_iteration=True)\n",
    "    .batch(batch_size)\n",
    "    .prefetch(tf.data.AUTOTUNE)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "181e008d-e87d-48d0-b42a-36ca9bfe2124",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.4038 - lr: 0.0034\n",
      "Epoch 2/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3970 - lr: 0.0034\n",
      "Epoch 3/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3867 - lr: 0.0034\n",
      "Epoch 4/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.3789 - lr: 0.0034\n",
      "Epoch 5/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3714 - lr: 0.0034\n",
      "Epoch 6/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3669 - lr: 0.0034\n",
      "Epoch 7/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3595 - lr: 0.0034\n",
      "Epoch 8/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3547 - lr: 0.0034\n",
      "Epoch 9/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.3506 - lr: 0.0034\n",
      "Epoch 10/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3439 - lr: 0.0034\n",
      "Epoch 11/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3378 - lr: 0.0032\n",
      "Epoch 12/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.3354 - lr: 0.0032\n",
      "Epoch 13/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3368 - lr: 0.0032\n",
      "Epoch 14/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3253 - lr: 0.0032\n",
      "Epoch 15/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3226 - lr: 0.0032\n",
      "Epoch 16/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.3179 - lr: 0.0032\n",
      "Epoch 17/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3189 - lr: 0.0032\n",
      "Epoch 18/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3145 - lr: 0.0032\n",
      "Epoch 19/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3488 - lr: 0.0032\n",
      "Epoch 20/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3151 - lr: 0.0032\n",
      "Epoch 21/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.3249 - lr: 0.0031\n",
      "Epoch 22/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3057 - lr: 0.0031\n",
      "Epoch 23/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3024 - lr: 0.0031\n",
      "Epoch 24/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3022 - lr: 0.0031\n",
      "Epoch 25/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2994 - lr: 0.0031\n",
      "Epoch 26/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2974 - lr: 0.0031\n",
      "Epoch 27/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2931 - lr: 0.0031\n",
      "Epoch 28/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.2985 - lr: 0.0031\n",
      "Epoch 29/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2896 - lr: 0.0031\n",
      "Epoch 30/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2892 - lr: 0.0031\n",
      "Epoch 31/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2854 - lr: 0.0029\n",
      "Epoch 32/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2857 - lr: 0.0029\n",
      "Epoch 33/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2798 - lr: 0.0029\n",
      "Epoch 34/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2794 - lr: 0.0029\n",
      "Epoch 35/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2762 - lr: 0.0029\n",
      "Epoch 36/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2772 - lr: 0.0029\n",
      "Epoch 37/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.2768 - lr: 0.0029\n",
      "Epoch 38/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2750 - lr: 0.0029\n",
      "Epoch 39/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2730 - lr: 0.0029\n",
      "Epoch 40/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2725 - lr: 0.0029\n",
      "Epoch 41/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2704 - lr: 0.0028\n",
      "Epoch 42/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2669 - lr: 0.0028\n",
      "Epoch 43/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2658 - lr: 0.0028\n",
      "Epoch 44/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2646 - lr: 0.0028\n",
      "Epoch 45/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2659 - lr: 0.0028\n",
      "Epoch 46/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2633 - lr: 0.0028\n",
      "Epoch 47/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.2619 - lr: 0.0028\n",
      "Epoch 48/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3544 - lr: 0.0028\n",
      "Epoch 49/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.3018 - lr: 0.0028\n",
      "Epoch 50/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2784 - lr: 0.0028\n",
      "Epoch 51/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2717 - lr: 0.0026\n",
      "Epoch 52/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2656 - lr: 0.0026\n",
      "Epoch 53/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2632 - lr: 0.0026\n",
      "Epoch 54/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2690 - lr: 0.0026\n",
      "Epoch 55/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2716 - lr: 0.0026\n",
      "Epoch 56/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2614 - lr: 0.0026\n",
      "Epoch 57/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2591 - lr: 0.0026\n",
      "Epoch 58/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2588 - lr: 0.0026\n",
      "Epoch 59/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2579 - lr: 0.0026\n",
      "Epoch 60/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.2565 - lr: 0.0026\n",
      "Epoch 61/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2556 - lr: 0.0025\n",
      "Epoch 62/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2467 - lr: 0.0025\n",
      "Epoch 63/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2464 - lr: 0.0025\n",
      "Epoch 64/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.2454 - lr: 0.0025\n",
      "Epoch 65/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2528 - lr: 0.0025\n",
      "Epoch 66/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2504 - lr: 0.0025\n",
      "Epoch 67/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2503 - lr: 0.0025\n",
      "Epoch 68/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2448 - lr: 0.0025\n",
      "Epoch 69/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.2473 - lr: 0.0025\n",
      "Epoch 70/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2446 - lr: 0.0025\n",
      "Epoch 71/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2419 - lr: 0.0024\n",
      "Epoch 72/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2369 - lr: 0.0024\n",
      "Epoch 73/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2369 - lr: 0.0024\n",
      "Epoch 74/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.2392 - lr: 0.0024\n",
      "Epoch 75/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2375 - lr: 0.0024\n",
      "Epoch 76/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2380 - lr: 0.0024\n",
      "Epoch 77/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2359 - lr: 0.0024\n",
      "Epoch 78/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.2339 - lr: 0.0024\n",
      "Epoch 79/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.2346 - lr: 0.0024\n",
      "Epoch 80/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2378 - lr: 0.0024\n",
      "Epoch 81/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2469 - lr: 0.0023\n",
      "Epoch 82/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2322 - lr: 0.0023\n",
      "Epoch 83/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2292 - lr: 0.0023\n",
      "Epoch 84/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2305 - lr: 0.0023\n",
      "Epoch 85/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2308 - lr: 0.0023\n",
      "Epoch 86/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2334 - lr: 0.0023\n",
      "Epoch 87/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2323 - lr: 0.0023\n",
      "Epoch 88/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2960 - lr: 0.0023\n",
      "Epoch 89/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2571 - lr: 0.0023\n",
      "Epoch 90/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2301 - lr: 0.0023\n",
      "Epoch 91/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2425 - lr: 0.0022\n",
      "Epoch 92/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2273 - lr: 0.0022\n",
      "Epoch 93/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2232 - lr: 0.0022\n",
      "Epoch 94/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2241 - lr: 0.0022\n",
      "Epoch 95/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2220 - lr: 0.0022\n",
      "Epoch 96/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2231 - lr: 0.0022\n",
      "Epoch 97/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2256 - lr: 0.0022\n",
      "Epoch 98/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2299 - lr: 0.0022\n",
      "Epoch 99/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2228 - lr: 0.0022\n",
      "Epoch 100/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2222 - lr: 0.0022\n",
      "Epoch 101/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2245 - lr: 0.0020\n",
      "Epoch 102/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2175 - lr: 0.0020\n",
      "Epoch 103/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2132 - lr: 0.0020\n",
      "Epoch 104/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2226 - lr: 0.0020\n",
      "Epoch 105/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2166 - lr: 0.0020\n",
      "Epoch 106/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2205 - lr: 0.0020\n",
      "Epoch 107/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2135 - lr: 0.0020\n",
      "Epoch 108/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2125 - lr: 0.0020\n",
      "Epoch 109/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2114 - lr: 0.0020\n",
      "Epoch 110/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2147 - lr: 0.0020\n",
      "Epoch 111/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2110 - lr: 0.0019\n",
      "Epoch 112/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2087 - lr: 0.0019\n",
      "Epoch 113/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2090 - lr: 0.0019\n",
      "Epoch 114/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2103 - lr: 0.0019\n",
      "Epoch 115/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2093 - lr: 0.0019\n",
      "Epoch 116/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2081 - lr: 0.0019\n",
      "Epoch 117/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2083 - lr: 0.0019\n",
      "Epoch 118/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2086 - lr: 0.0019\n",
      "Epoch 119/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2078 - lr: 0.0019\n",
      "Epoch 120/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2095 - lr: 0.0019\n",
      "Epoch 121/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2044 - lr: 0.0018\n",
      "Epoch 122/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2035 - lr: 0.0018\n",
      "Epoch 123/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2045 - lr: 0.0018\n",
      "Epoch 124/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2060 - lr: 0.0018\n",
      "Epoch 125/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2110 - lr: 0.0018\n",
      "Epoch 126/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2032 - lr: 0.0018\n",
      "Epoch 127/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1991 - lr: 0.0018\n",
      "Epoch 128/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2026 - lr: 0.0018\n",
      "Epoch 129/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2094 - lr: 0.0018\n",
      "Epoch 130/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2028 - lr: 0.0018\n",
      "Epoch 131/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1991 - lr: 0.0018\n",
      "Epoch 132/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1986 - lr: 0.0018\n",
      "Epoch 133/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1990 - lr: 0.0018\n",
      "Epoch 134/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.2002 - lr: 0.0018\n",
      "Epoch 135/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1982 - lr: 0.0018\n",
      "Epoch 136/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1973 - lr: 0.0018\n",
      "Epoch 137/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1950 - lr: 0.0018\n",
      "Epoch 138/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1962 - lr: 0.0018\n",
      "Epoch 139/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1942 - lr: 0.0018\n",
      "Epoch 140/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1950 - lr: 0.0018\n",
      "Epoch 141/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1948 - lr: 0.0017\n",
      "Epoch 142/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1936 - lr: 0.0017\n",
      "Epoch 143/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1939 - lr: 0.0017\n",
      "Epoch 144/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1909 - lr: 0.0017\n",
      "Epoch 145/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1893 - lr: 0.0017\n",
      "Epoch 146/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1900 - lr: 0.0017\n",
      "Epoch 147/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1909 - lr: 0.0017\n",
      "Epoch 148/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1950 - lr: 0.0017\n",
      "Epoch 149/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1937 - lr: 0.0017\n",
      "Epoch 150/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1898 - lr: 0.0017\n",
      "Epoch 151/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1865 - lr: 0.0016\n",
      "Epoch 152/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1865 - lr: 0.0016\n",
      "Epoch 153/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1862 - lr: 0.0016\n",
      "Epoch 154/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1877 - lr: 0.0016\n",
      "Epoch 155/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1861 - lr: 0.0016\n",
      "Epoch 156/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1845 - lr: 0.0016\n",
      "Epoch 157/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1892 - lr: 0.0016\n",
      "Epoch 158/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1863 - lr: 0.0016\n",
      "Epoch 159/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1847 - lr: 0.0016\n",
      "Epoch 160/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1868 - lr: 0.0016\n",
      "Epoch 161/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1823 - lr: 0.0015\n",
      "Epoch 162/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1825 - lr: 0.0015\n",
      "Epoch 163/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1817 - lr: 0.0015\n",
      "Epoch 164/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1844 - lr: 0.0015\n",
      "Epoch 165/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1903 - lr: 0.0015\n",
      "Epoch 166/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1831 - lr: 0.0015\n",
      "Epoch 167/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1805 - lr: 0.0015\n",
      "Epoch 168/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1812 - lr: 0.0015\n",
      "Epoch 169/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1818 - lr: 0.0015\n",
      "Epoch 170/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1851 - lr: 0.0015\n",
      "Epoch 171/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1783 - lr: 0.0014\n",
      "Epoch 172/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1782 - lr: 0.0014\n",
      "Epoch 173/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1768 - lr: 0.0014\n",
      "Epoch 174/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1772 - lr: 0.0014\n",
      "Epoch 175/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1755 - lr: 0.0014\n",
      "Epoch 176/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1774 - lr: 0.0014\n",
      "Epoch 177/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1748 - lr: 0.0014\n",
      "Epoch 178/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1748 - lr: 0.0014\n",
      "Epoch 179/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1777 - lr: 0.0014\n",
      "Epoch 180/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1778 - lr: 0.0014\n",
      "Epoch 181/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1766 - lr: 0.0014\n",
      "Epoch 182/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1779 - lr: 0.0014\n",
      "Epoch 183/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1744 - lr: 0.0014\n",
      "Epoch 184/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1735 - lr: 0.0014\n",
      "Epoch 185/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1747 - lr: 0.0014\n",
      "Epoch 186/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1716 - lr: 0.0014\n",
      "Epoch 187/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1719 - lr: 0.0014\n",
      "Epoch 188/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1720 - lr: 0.0014\n",
      "Epoch 189/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1707 - lr: 0.0014\n",
      "Epoch 190/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1739 - lr: 0.0014\n",
      "Epoch 191/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1692 - lr: 0.0013\n",
      "Epoch 192/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1677 - lr: 0.0013\n",
      "Epoch 193/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1667 - lr: 0.0013\n",
      "Epoch 194/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1665 - lr: 0.0013\n",
      "Epoch 195/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1725 - lr: 0.0013\n",
      "Epoch 196/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1703 - lr: 0.0013\n",
      "Epoch 197/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1676 - lr: 0.0013\n",
      "Epoch 198/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1659 - lr: 0.0013\n",
      "Epoch 199/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1673 - lr: 0.0013\n",
      "Epoch 200/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1680 - lr: 0.0013\n",
      "Epoch 201/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1657 - lr: 0.0012\n",
      "Epoch 202/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1648 - lr: 0.0012\n",
      "Epoch 203/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1664 - lr: 0.0012\n",
      "Epoch 204/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1663 - lr: 0.0012\n",
      "Epoch 205/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1631 - lr: 0.0012\n",
      "Epoch 206/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1673 - lr: 0.0012\n",
      "Epoch 207/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1621 - lr: 0.0012\n",
      "Epoch 208/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1626 - lr: 0.0012\n",
      "Epoch 209/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1663 - lr: 0.0012\n",
      "Epoch 210/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1626 - lr: 0.0012\n",
      "Epoch 211/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1595 - lr: 0.0012\n",
      "Epoch 212/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1607 - lr: 0.0012\n",
      "Epoch 213/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1638 - lr: 0.0012\n",
      "Epoch 214/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1604 - lr: 0.0012\n",
      "Epoch 215/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1599 - lr: 0.0012\n",
      "Epoch 216/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1589 - lr: 0.0012\n",
      "Epoch 217/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1582 - lr: 0.0012\n",
      "Epoch 218/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1649 - lr: 0.0012\n",
      "Epoch 219/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1609 - lr: 0.0012\n",
      "Epoch 220/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1611 - lr: 0.0012\n",
      "Epoch 221/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1608 - lr: 0.0011\n",
      "Epoch 222/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1561 - lr: 0.0011\n",
      "Epoch 223/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1553 - lr: 0.0011\n",
      "Epoch 224/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1571 - lr: 0.0011\n",
      "Epoch 225/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1556 - lr: 0.0011\n",
      "Epoch 226/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1549 - lr: 0.0011\n",
      "Epoch 227/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1572 - lr: 0.0011\n",
      "Epoch 228/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1572 - lr: 0.0011\n",
      "Epoch 229/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1562 - lr: 0.0011\n",
      "Epoch 230/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1620 - lr: 0.0011\n",
      "Epoch 231/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1623 - lr: 0.0010\n",
      "Epoch 232/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1547 - lr: 0.0010\n",
      "Epoch 233/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1560 - lr: 0.0010\n",
      "Epoch 234/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1528 - lr: 0.0010\n",
      "Epoch 235/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1509 - lr: 0.0010\n",
      "Epoch 236/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1509 - lr: 0.0010\n",
      "Epoch 237/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1583 - lr: 0.0010\n",
      "Epoch 238/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1569 - lr: 0.0010\n",
      "Epoch 239/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1544 - lr: 0.0010\n",
      "Epoch 240/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1525 - lr: 0.0010\n",
      "Epoch 241/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1533 - lr: 9.9716e-04\n",
      "Epoch 242/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1514 - lr: 9.9716e-04\n",
      "Epoch 243/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1501 - lr: 9.9716e-04\n",
      "Epoch 244/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1500 - lr: 9.9716e-04\n",
      "Epoch 245/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1496 - lr: 9.9716e-04\n",
      "Epoch 246/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1503 - lr: 9.9716e-04\n",
      "Epoch 247/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.1476 - lr: 9.9716e-04\n",
      "Epoch 248/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1490 - lr: 9.9716e-04\n",
      "Epoch 249/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1475 - lr: 9.9716e-04\n",
      "Epoch 250/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1480 - lr: 9.9716e-04\n",
      "Epoch 251/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1484 - lr: 9.4731e-04\n",
      "Epoch 252/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1459 - lr: 9.4731e-04\n",
      "Epoch 253/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1456 - lr: 9.4731e-04\n",
      "Epoch 254/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1484 - lr: 9.4731e-04\n",
      "Epoch 255/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1516 - lr: 9.4731e-04\n",
      "Epoch 256/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1455 - lr: 9.4731e-04\n",
      "Epoch 257/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1464 - lr: 9.4731e-04\n",
      "Epoch 258/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1453 - lr: 9.4731e-04\n",
      "Epoch 259/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1487 - lr: 9.4731e-04\n",
      "Epoch 260/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1457 - lr: 9.4731e-04\n",
      "Epoch 261/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1430 - lr: 8.9994e-04\n",
      "Epoch 262/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1436 - lr: 8.9994e-04\n",
      "Epoch 263/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1424 - lr: 8.9994e-04\n",
      "Epoch 264/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1429 - lr: 8.9994e-04\n",
      "Epoch 265/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1425 - lr: 8.9994e-04\n",
      "Epoch 266/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1457 - lr: 8.9994e-04\n",
      "Epoch 267/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1430 - lr: 8.9994e-04\n",
      "Epoch 268/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1415 - lr: 8.9994e-04\n",
      "Epoch 269/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1470 - lr: 8.9994e-04\n",
      "Epoch 270/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1434 - lr: 8.9994e-04\n",
      "Epoch 271/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1435 - lr: 8.5494e-04\n",
      "Epoch 272/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1413 - lr: 8.5494e-04\n",
      "Epoch 273/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1401 - lr: 8.5494e-04\n",
      "Epoch 274/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1396 - lr: 8.5494e-04\n",
      "Epoch 275/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1398 - lr: 8.5494e-04\n",
      "Epoch 276/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1401 - lr: 8.5494e-04\n",
      "Epoch 277/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1389 - lr: 8.5494e-04\n",
      "Epoch 278/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1382 - lr: 8.5494e-04\n",
      "Epoch 279/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1414 - lr: 8.5494e-04\n",
      "Epoch 280/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1423 - lr: 8.5494e-04\n",
      "Epoch 281/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1378 - lr: 8.1220e-04\n",
      "Epoch 282/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1406 - lr: 8.1220e-04\n",
      "Epoch 283/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1402 - lr: 8.1220e-04\n",
      "Epoch 284/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1378 - lr: 8.1220e-04\n",
      "Epoch 285/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1365 - lr: 8.1220e-04\n",
      "Epoch 286/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1368 - lr: 8.1220e-04\n",
      "Epoch 287/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1369 - lr: 8.1220e-04\n",
      "Epoch 288/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1388 - lr: 8.1220e-04\n",
      "Epoch 289/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1381 - lr: 8.1220e-04\n",
      "Epoch 290/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1385 - lr: 8.1220e-04\n",
      "Epoch 291/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1404 - lr: 7.7159e-04\n",
      "Epoch 292/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1361 - lr: 7.7159e-04\n",
      "Epoch 293/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1349 - lr: 7.7159e-04\n",
      "Epoch 294/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1335 - lr: 7.7159e-04\n",
      "Epoch 295/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1360 - lr: 7.7159e-04\n",
      "Epoch 296/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1336 - lr: 7.7159e-04\n",
      "Epoch 297/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1329 - lr: 7.7159e-04\n",
      "Epoch 298/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1371 - lr: 7.7159e-04\n",
      "Epoch 299/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1362 - lr: 7.7159e-04\n",
      "Epoch 300/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1345 - lr: 7.7159e-04\n",
      "Epoch 301/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1315 - lr: 7.3301e-04\n",
      "Epoch 302/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1335 - lr: 7.3301e-04\n",
      "Epoch 303/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1377 - lr: 7.3301e-04\n",
      "Epoch 304/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1301 - lr: 7.3301e-04\n",
      "Epoch 305/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1294 - lr: 7.3301e-04\n",
      "Epoch 306/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1302 - lr: 7.3301e-04\n",
      "Epoch 307/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1329 - lr: 7.3301e-04\n",
      "Epoch 308/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1376 - lr: 7.3301e-04\n",
      "Epoch 309/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1338 - lr: 7.3301e-04\n",
      "Epoch 310/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1374 - lr: 7.3301e-04\n",
      "Epoch 311/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1300 - lr: 6.9636e-04\n",
      "Epoch 312/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1296 - lr: 6.9636e-04\n",
      "Epoch 313/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1278 - lr: 6.9636e-04\n",
      "Epoch 314/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1308 - lr: 6.9636e-04\n",
      "Epoch 315/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1275 - lr: 6.9636e-04\n",
      "Epoch 316/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1290 - lr: 6.9636e-04\n",
      "Epoch 317/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1291 - lr: 6.9636e-04\n",
      "Epoch 318/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1284 - lr: 6.9636e-04\n",
      "Epoch 319/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1304 - lr: 6.9636e-04\n",
      "Epoch 320/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1293 - lr: 6.9636e-04\n",
      "Epoch 321/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1271 - lr: 6.6154e-04\n",
      "Epoch 322/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1248 - lr: 6.6154e-04\n",
      "Epoch 323/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1274 - lr: 6.6154e-04\n",
      "Epoch 324/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1275 - lr: 6.6154e-04\n",
      "Epoch 325/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1258 - lr: 6.6154e-04\n",
      "Epoch 326/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1271 - lr: 6.6154e-04\n",
      "Epoch 327/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1268 - lr: 6.6154e-04\n",
      "Epoch 328/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1284 - lr: 6.6154e-04\n",
      "Epoch 329/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1250 - lr: 6.6154e-04\n",
      "Epoch 330/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1259 - lr: 6.6154e-04\n",
      "Epoch 331/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1245 - lr: 6.2846e-04\n",
      "Epoch 332/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1245 - lr: 6.2846e-04\n",
      "Epoch 333/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1246 - lr: 6.2846e-04\n",
      "Epoch 334/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1243 - lr: 6.2846e-04\n",
      "Epoch 335/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1224 - lr: 6.2846e-04\n",
      "Epoch 336/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1217 - lr: 6.2846e-04\n",
      "Epoch 337/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1231 - lr: 6.2846e-04\n",
      "Epoch 338/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1240 - lr: 6.2846e-04\n",
      "Epoch 339/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1240 - lr: 6.2846e-04\n",
      "Epoch 340/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1237 - lr: 6.2846e-04\n",
      "Epoch 341/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1229 - lr: 5.9704e-04\n",
      "Epoch 342/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1250 - lr: 5.9704e-04\n",
      "Epoch 343/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1222 - lr: 5.9704e-04\n",
      "Epoch 344/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1214 - lr: 5.9704e-04\n",
      "Epoch 345/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1224 - lr: 5.9704e-04\n",
      "Epoch 346/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1228 - lr: 5.9704e-04\n",
      "Epoch 347/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1234 - lr: 5.9704e-04\n",
      "Epoch 348/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1195 - lr: 5.9704e-04\n",
      "Epoch 349/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1218 - lr: 5.9704e-04\n",
      "Epoch 350/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1232 - lr: 5.9704e-04\n",
      "Epoch 351/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1207 - lr: 5.6719e-04\n",
      "Epoch 352/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1178 - lr: 5.6719e-04\n",
      "Epoch 353/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1175 - lr: 5.6719e-04\n",
      "Epoch 354/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1182 - lr: 5.6719e-04\n",
      "Epoch 355/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1179 - lr: 5.6719e-04\n",
      "Epoch 356/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1191 - lr: 5.6719e-04\n",
      "Epoch 357/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1186 - lr: 5.6719e-04\n",
      "Epoch 358/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1179 - lr: 5.6719e-04\n",
      "Epoch 359/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1195 - lr: 5.6719e-04\n",
      "Epoch 360/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1197 - lr: 5.6719e-04\n",
      "Epoch 361/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1191 - lr: 5.3883e-04\n",
      "Epoch 362/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1172 - lr: 5.3883e-04\n",
      "Epoch 363/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1160 - lr: 5.3883e-04\n",
      "Epoch 364/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1159 - lr: 5.3883e-04\n",
      "Epoch 365/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1192 - lr: 5.3883e-04\n",
      "Epoch 366/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1157 - lr: 5.3883e-04\n",
      "Epoch 367/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1147 - lr: 5.3883e-04\n",
      "Epoch 368/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1154 - lr: 5.3883e-04\n",
      "Epoch 369/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1157 - lr: 5.3883e-04\n",
      "Epoch 370/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1164 - lr: 5.3883e-04\n",
      "Epoch 371/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1148 - lr: 5.1189e-04\n",
      "Epoch 372/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1159 - lr: 5.1189e-04\n",
      "Epoch 373/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1144 - lr: 5.1189e-04\n",
      "Epoch 374/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1137 - lr: 5.1189e-04\n",
      "Epoch 375/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1147 - lr: 5.1189e-04\n",
      "Epoch 376/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1139 - lr: 5.1189e-04\n",
      "Epoch 377/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1139 - lr: 5.1189e-04\n",
      "Epoch 378/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1139 - lr: 5.1189e-04\n",
      "Epoch 379/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1159 - lr: 5.1189e-04\n",
      "Epoch 380/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1155 - lr: 5.1189e-04\n",
      "Epoch 381/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1121 - lr: 4.8629e-04\n",
      "Epoch 382/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1139 - lr: 4.8629e-04\n",
      "Epoch 383/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1154 - lr: 4.8629e-04\n",
      "Epoch 384/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1156 - lr: 4.8629e-04\n",
      "Epoch 385/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1157 - lr: 4.8629e-04\n",
      "Epoch 386/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1132 - lr: 4.8629e-04\n",
      "Epoch 387/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1124 - lr: 4.8629e-04\n",
      "Epoch 388/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1116 - lr: 4.8629e-04\n",
      "Epoch 389/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1120 - lr: 4.8629e-04\n",
      "Epoch 390/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.1147 - lr: 4.8629e-04\n",
      "Epoch 391/1000\n",
      "1114/1114 [==============================] - 17s 15ms/step - loss: 1.1111 - lr: 4.6198e-04\n",
      "Epoch 392/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1108 - lr: 4.6198e-04\n",
      "Epoch 393/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1101 - lr: 4.6198e-04\n",
      "Epoch 394/1000\n",
      "1114/1114 [==============================] - 16s 14ms/step - loss: 1.1129 - lr: 4.6198e-04\n",
      "Epoch 395/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.1137 - lr: 4.6198e-04\n",
      "Epoch 396/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.1102 - lr: 4.6198e-04\n",
      "Epoch 397/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1137 - lr: 4.6198e-04\n",
      "Epoch 398/1000\n",
      "1114/1114 [==============================] - 17s 16ms/step - loss: 1.1091 - lr: 4.6198e-04\n",
      "Epoch 399/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.1088 - lr: 4.6198e-04\n",
      "Epoch 400/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1091 - lr: 4.6198e-04\n",
      "Epoch 401/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1105 - lr: 4.3888e-04\n",
      "Epoch 402/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1133 - lr: 4.3888e-04\n",
      "Epoch 403/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1094 - lr: 4.3888e-04\n",
      "Epoch 404/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1079 - lr: 4.3888e-04\n",
      "Epoch 405/1000\n",
      "1114/1114 [==============================] - 14s 13ms/step - loss: 1.1082 - lr: 4.3888e-04\n",
      "Epoch 406/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.1089 - lr: 4.3888e-04\n",
      "Epoch 407/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1078 - lr: 4.3888e-04\n",
      "Epoch 408/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1085 - lr: 4.3888e-04\n",
      "Epoch 409/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1083 - lr: 4.3888e-04\n",
      "Epoch 410/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1074 - lr: 4.3888e-04\n",
      "Epoch 411/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1070 - lr: 4.1693e-04\n",
      "Epoch 412/1000\n",
      "1114/1114 [==============================] - 17s 15ms/step - loss: 1.1078 - lr: 4.1693e-04\n",
      "Epoch 413/1000\n",
      "1114/1114 [==============================] - 16s 15ms/step - loss: 1.1077 - lr: 4.1693e-04\n",
      "Epoch 414/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.1087 - lr: 4.1693e-04\n",
      "Epoch 415/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1071 - lr: 4.1693e-04\n",
      "Epoch 416/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.1074 - lr: 4.1693e-04\n",
      "Epoch 417/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.1075 - lr: 4.1693e-04\n",
      "Epoch 418/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.1075 - lr: 4.1693e-04\n",
      "Epoch 419/1000\n",
      "1114/1114 [==============================] - 17s 16ms/step - loss: 1.1075 - lr: 4.1693e-04\n",
      "Epoch 420/1000\n",
      "1114/1114 [==============================] - 17s 15ms/step - loss: 1.1103 - lr: 4.1693e-04\n",
      "Epoch 421/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1071 - lr: 3.9609e-04\n",
      "Epoch 422/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1044 - lr: 3.9609e-04\n",
      "Epoch 423/1000\n",
      "1114/1114 [==============================] - 17s 16ms/step - loss: 1.1048 - lr: 3.9609e-04\n",
      "Epoch 424/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1072 - lr: 3.9609e-04\n",
      "Epoch 425/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1053 - lr: 3.9609e-04\n",
      "Epoch 426/1000\n",
      "1114/1114 [==============================] - 22s 20ms/step - loss: 1.1052 - lr: 3.9609e-04\n",
      "Epoch 427/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1058 - lr: 3.9609e-04\n",
      "Epoch 428/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1061 - lr: 3.9609e-04\n",
      "Epoch 429/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1053 - lr: 3.9609e-04\n",
      "Epoch 430/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1047 - lr: 3.9609e-04\n",
      "Epoch 431/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1048 - lr: 3.7628e-04\n",
      "Epoch 432/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1035 - lr: 3.7628e-04\n",
      "Epoch 433/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1035 - lr: 3.7628e-04\n",
      "Epoch 434/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1040 - lr: 3.7628e-04\n",
      "Epoch 435/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1046 - lr: 3.7628e-04\n",
      "Epoch 436/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1029 - lr: 3.7628e-04\n",
      "Epoch 437/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1026 - lr: 3.7628e-04\n",
      "Epoch 438/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1026 - lr: 3.7628e-04\n",
      "Epoch 439/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1032 - lr: 3.7628e-04\n",
      "Epoch 440/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1047 - lr: 3.7628e-04\n",
      "Epoch 441/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1020 - lr: 3.5747e-04\n",
      "Epoch 442/1000\n",
      "1114/1114 [==============================] - 18s 17ms/step - loss: 1.1022 - lr: 3.5747e-04\n",
      "Epoch 443/1000\n",
      "1114/1114 [==============================] - 18s 17ms/step - loss: 1.1025 - lr: 3.5747e-04\n",
      "Epoch 444/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1016 - lr: 3.5747e-04\n",
      "Epoch 445/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1025 - lr: 3.5747e-04\n",
      "Epoch 446/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1019 - lr: 3.5747e-04\n",
      "Epoch 447/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1019 - lr: 3.5747e-04\n",
      "Epoch 448/1000\n",
      "1114/1114 [==============================] - 18s 17ms/step - loss: 1.1020 - lr: 3.5747e-04\n",
      "Epoch 449/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1027 - lr: 3.5747e-04\n",
      "Epoch 450/1000\n",
      "1114/1114 [==============================] - 17s 15ms/step - loss: 1.1044 - lr: 3.5747e-04\n",
      "Epoch 451/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.1017 - lr: 3.3960e-04\n",
      "Epoch 452/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1022 - lr: 3.3960e-04\n",
      "Epoch 453/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1015 - lr: 3.3960e-04\n",
      "Epoch 454/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1001 - lr: 3.3960e-04\n",
      "Epoch 455/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.1008 - lr: 3.3960e-04\n",
      "Epoch 456/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.1010 - lr: 3.3960e-04\n",
      "Epoch 457/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.1007 - lr: 3.3960e-04\n",
      "Epoch 458/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1014 - lr: 3.3960e-04\n",
      "Epoch 459/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1026 - lr: 3.3960e-04\n",
      "Epoch 460/1000\n",
      "1114/1114 [==============================] - 18s 17ms/step - loss: 1.1050 - lr: 3.3960e-04\n",
      "Epoch 461/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1002 - lr: 3.2262e-04\n",
      "Epoch 462/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0996 - lr: 3.2262e-04\n",
      "Epoch 463/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0994 - lr: 3.2262e-04\n",
      "Epoch 464/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0998 - lr: 3.2262e-04\n",
      "Epoch 465/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0999 - lr: 3.2262e-04\n",
      "Epoch 466/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0994 - lr: 3.2262e-04\n",
      "Epoch 467/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0987 - lr: 3.2262e-04\n",
      "Epoch 468/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.1006 - lr: 3.2262e-04\n",
      "Epoch 469/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0999 - lr: 3.2262e-04\n",
      "Epoch 470/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0990 - lr: 3.2262e-04\n",
      "Epoch 471/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0992 - lr: 3.0649e-04\n",
      "Epoch 472/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0986 - lr: 3.0649e-04\n",
      "Epoch 473/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0990 - lr: 3.0649e-04\n",
      "Epoch 474/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0980 - lr: 3.0649e-04\n",
      "Epoch 475/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0986 - lr: 3.0649e-04\n",
      "Epoch 476/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0978 - lr: 3.0649e-04\n",
      "Epoch 477/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0978 - lr: 3.0649e-04\n",
      "Epoch 478/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0973 - lr: 3.0649e-04\n",
      "Epoch 479/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0978 - lr: 3.0649e-04\n",
      "Epoch 480/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0988 - lr: 3.0649e-04\n",
      "Epoch 481/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0982 - lr: 2.9116e-04\n",
      "Epoch 482/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0970 - lr: 2.9116e-04\n",
      "Epoch 483/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0971 - lr: 2.9116e-04\n",
      "Epoch 484/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0967 - lr: 2.9116e-04\n",
      "Epoch 485/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0975 - lr: 2.9116e-04\n",
      "Epoch 486/1000\n",
      "1114/1114 [==============================] - 18s 17ms/step - loss: 1.0974 - lr: 2.9116e-04\n",
      "Epoch 487/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.0976 - lr: 2.9116e-04\n",
      "Epoch 488/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0970 - lr: 2.9116e-04\n",
      "Epoch 489/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0965 - lr: 2.9116e-04\n",
      "Epoch 490/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0976 - lr: 2.9116e-04\n",
      "Epoch 491/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0971 - lr: 2.7660e-04\n",
      "Epoch 492/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0967 - lr: 2.7660e-04\n",
      "Epoch 493/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0962 - lr: 2.7660e-04\n",
      "Epoch 494/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0975 - lr: 2.7660e-04\n",
      "Epoch 495/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0963 - lr: 2.7660e-04\n",
      "Epoch 496/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0956 - lr: 2.7660e-04\n",
      "Epoch 497/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0955 - lr: 2.7660e-04\n",
      "Epoch 498/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0974 - lr: 2.7660e-04\n",
      "Epoch 499/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0957 - lr: 2.7660e-04\n",
      "Epoch 500/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0964 - lr: 2.7660e-04\n",
      "Epoch 501/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0964 - lr: 2.6277e-04\n",
      "Epoch 502/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0951 - lr: 2.6277e-04\n",
      "Epoch 503/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0949 - lr: 2.6277e-04\n",
      "Epoch 504/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0953 - lr: 2.6277e-04\n",
      "Epoch 505/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0951 - lr: 2.6277e-04\n",
      "Epoch 506/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0948 - lr: 2.6277e-04\n",
      "Epoch 507/1000\n",
      "1114/1114 [==============================] - 20s 18ms/step - loss: 1.0952 - lr: 2.6277e-04\n",
      "Epoch 508/1000\n",
      "1114/1114 [==============================] - 19s 17ms/step - loss: 1.0960 - lr: 2.6277e-04\n",
      "Epoch 509/1000\n",
      "1114/1114 [==============================] - 18s 16ms/step - loss: 1.0967 - lr: 2.6277e-04\n",
      "Epoch 510/1000\n",
      "1114/1114 [==============================] - 17s 16ms/step - loss: 1.0956 - lr: 2.6277e-04\n",
      "Epoch 511/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0941 - lr: 2.4963e-04\n",
      "Epoch 512/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0946 - lr: 2.4963e-04\n",
      "Epoch 513/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0949 - lr: 2.4963e-04\n",
      "Epoch 514/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0946 - lr: 2.4963e-04\n",
      "Epoch 515/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0937 - lr: 2.4963e-04\n",
      "Epoch 516/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0935 - lr: 2.4963e-04\n",
      "Epoch 517/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0935 - lr: 2.4963e-04\n",
      "Epoch 518/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0944 - lr: 2.4963e-04\n",
      "Epoch 519/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0938 - lr: 2.4963e-04\n",
      "Epoch 520/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0944 - lr: 2.4963e-04\n",
      "Epoch 521/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0937 - lr: 2.3715e-04\n",
      "Epoch 522/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0931 - lr: 2.3715e-04\n",
      "Epoch 523/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0928 - lr: 2.3715e-04\n",
      "Epoch 524/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0931 - lr: 2.3715e-04\n",
      "Epoch 525/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0935 - lr: 2.3715e-04\n",
      "Epoch 526/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0936 - lr: 2.3715e-04\n",
      "Epoch 527/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0930 - lr: 2.3715e-04\n",
      "Epoch 528/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0944 - lr: 2.3715e-04\n",
      "Epoch 529/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0933 - lr: 2.3715e-04\n",
      "Epoch 530/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0933 - lr: 2.3715e-04\n",
      "Epoch 531/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0927 - lr: 2.2529e-04\n",
      "Epoch 532/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0922 - lr: 2.2529e-04\n",
      "Epoch 533/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0925 - lr: 2.2529e-04\n",
      "Epoch 534/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.0919 - lr: 2.2529e-04\n",
      "Epoch 535/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0919 - lr: 2.2529e-04\n",
      "Epoch 536/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0919 - lr: 2.2529e-04\n",
      "Epoch 537/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0928 - lr: 2.2529e-04\n",
      "Epoch 538/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0940 - lr: 2.2529e-04\n",
      "Epoch 539/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0946 - lr: 2.2529e-04\n",
      "Epoch 540/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0920 - lr: 2.2529e-04\n",
      "Epoch 541/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0915 - lr: 2.1403e-04\n",
      "Epoch 542/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0922 - lr: 2.1403e-04\n",
      "Epoch 543/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0923 - lr: 2.1403e-04\n",
      "Epoch 544/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0930 - lr: 2.1403e-04\n",
      "Epoch 545/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0920 - lr: 2.1403e-04\n",
      "Epoch 546/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0909 - lr: 2.1403e-04\n",
      "Epoch 547/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0912 - lr: 2.1403e-04\n",
      "Epoch 548/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.0911 - lr: 2.1403e-04\n",
      "Epoch 549/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0916 - lr: 2.1403e-04\n",
      "Epoch 550/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0911 - lr: 2.1403e-04\n",
      "Epoch 551/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0907 - lr: 2.0333e-04\n",
      "Epoch 552/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0910 - lr: 2.0333e-04\n",
      "Epoch 553/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0910 - lr: 2.0333e-04\n",
      "Epoch 554/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0905 - lr: 2.0333e-04\n",
      "Epoch 555/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0905 - lr: 2.0333e-04\n",
      "Epoch 556/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0904 - lr: 2.0333e-04\n",
      "Epoch 557/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0904 - lr: 2.0333e-04\n",
      "Epoch 558/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0911 - lr: 2.0333e-04\n",
      "Epoch 559/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0905 - lr: 2.0333e-04\n",
      "Epoch 560/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0915 - lr: 2.0333e-04\n",
      "Epoch 561/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0905 - lr: 1.9316e-04\n",
      "Epoch 562/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0901 - lr: 1.9316e-04\n",
      "Epoch 563/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0905 - lr: 1.9316e-04\n",
      "Epoch 564/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0903 - lr: 1.9316e-04\n",
      "Epoch 565/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0900 - lr: 1.9316e-04\n",
      "Epoch 566/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0897 - lr: 1.9316e-04\n",
      "Epoch 567/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0896 - lr: 1.9316e-04\n",
      "Epoch 568/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0899 - lr: 1.9316e-04\n",
      "Epoch 569/1000\n",
      "1114/1114 [==============================] - 15s 14ms/step - loss: 1.0897 - lr: 1.9316e-04\n",
      "Epoch 570/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0893 - lr: 1.9316e-04\n",
      "Epoch 571/1000\n",
      "1114/1114 [==============================] - 15s 13ms/step - loss: 1.0891 - lr: 1.8350e-04\n",
      "Epoch 572/1000\n",
      " 159/1114 [===>..........................] - ETA: 12s - loss: 1.0922"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[32], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mbest_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_dataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mtensorboard_callback\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr_decay_callback\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheckpoint_callback\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py:65\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     63\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m     64\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 65\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     66\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m     67\u001b[0m     filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py:1685\u001b[0m, in \u001b[0;36mModel.fit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[1;32m   1677\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m tf\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mexperimental\u001b[38;5;241m.\u001b[39mTrace(\n\u001b[1;32m   1678\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m   1679\u001b[0m     epoch_num\u001b[38;5;241m=\u001b[39mepoch,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1682\u001b[0m     _r\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m   1683\u001b[0m ):\n\u001b[1;32m   1684\u001b[0m     callbacks\u001b[38;5;241m.\u001b[39mon_train_batch_begin(step)\n\u001b[0;32m-> 1685\u001b[0m     tmp_logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1686\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m data_handler\u001b[38;5;241m.\u001b[39mshould_sync:\n\u001b[1;32m   1687\u001b[0m         context\u001b[38;5;241m.\u001b[39masync_wait()\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:150\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    148\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    149\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 150\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    151\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    152\u001b[0m   filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:894\u001b[0m, in \u001b[0;36mFunction.__call__\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    891\u001b[0m compiler \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mxla\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jit_compile \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnonXla\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    893\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m OptionalXlaContext(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jit_compile):\n\u001b[0;32m--> 894\u001b[0m   result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    896\u001b[0m new_tracing_count \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexperimental_get_tracing_count()\n\u001b[1;32m    897\u001b[0m without_tracing \u001b[38;5;241m=\u001b[39m (tracing_count \u001b[38;5;241m==\u001b[39m new_tracing_count)\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:926\u001b[0m, in \u001b[0;36mFunction._call\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    923\u001b[0m   \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[1;32m    924\u001b[0m   \u001b[38;5;66;03m# In this case we have created variables on the first call, so we run the\u001b[39;00m\n\u001b[1;32m    925\u001b[0m   \u001b[38;5;66;03m# defunned version which is guaranteed to never create variables.\u001b[39;00m\n\u001b[0;32m--> 926\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_no_variable_creation_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# pylint: disable=not-callable\u001b[39;00m\n\u001b[1;32m    927\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_variable_creation_fn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    928\u001b[0m   \u001b[38;5;66;03m# Release the lock early so that multiple threads can perform the call\u001b[39;00m\n\u001b[1;32m    929\u001b[0m   \u001b[38;5;66;03m# in parallel.\u001b[39;00m\n\u001b[1;32m    930\u001b[0m   \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock\u001b[38;5;241m.\u001b[39mrelease()\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py:143\u001b[0m, in \u001b[0;36mTracingCompiler.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    140\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock:\n\u001b[1;32m    141\u001b[0m   (concrete_function,\n\u001b[1;32m    142\u001b[0m    filtered_flat_args) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_maybe_define_function(args, kwargs)\n\u001b[0;32m--> 143\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mconcrete_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_flat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    144\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfiltered_flat_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcaptured_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconcrete_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcaptured_inputs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py:1757\u001b[0m, in \u001b[0;36mConcreteFunction._call_flat\u001b[0;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[1;32m   1753\u001b[0m possible_gradient_type \u001b[38;5;241m=\u001b[39m gradients_util\u001b[38;5;241m.\u001b[39mPossibleTapeGradientTypes(args)\n\u001b[1;32m   1754\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (possible_gradient_type \u001b[38;5;241m==\u001b[39m gradients_util\u001b[38;5;241m.\u001b[39mPOSSIBLE_GRADIENT_TYPES_NONE\n\u001b[1;32m   1755\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m executing_eagerly):\n\u001b[1;32m   1756\u001b[0m   \u001b[38;5;66;03m# No tape is watching; skip to running the function.\u001b[39;00m\n\u001b[0;32m-> 1757\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_build_call_outputs(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_inference_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1758\u001b[0m \u001b[43m      \u001b[49m\u001b[43mctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcancellation_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcancellation_manager\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m   1759\u001b[0m forward_backward \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_select_forward_and_backward_functions(\n\u001b[1;32m   1760\u001b[0m     args,\n\u001b[1;32m   1761\u001b[0m     possible_gradient_type,\n\u001b[1;32m   1762\u001b[0m     executing_eagerly)\n\u001b[1;32m   1763\u001b[0m forward_function, args_with_tangents \u001b[38;5;241m=\u001b[39m forward_backward\u001b[38;5;241m.\u001b[39mforward()\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py:381\u001b[0m, in \u001b[0;36m_EagerDefinedFunction.call\u001b[0;34m(self, ctx, args, cancellation_manager)\u001b[0m\n\u001b[1;32m    379\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _InterpolateFunctionError(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    380\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m cancellation_manager \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 381\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mexecute\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    382\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msignature\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    383\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnum_outputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_num_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    384\u001b[0m \u001b[43m        \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    385\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattrs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    386\u001b[0m \u001b[43m        \u001b[49m\u001b[43mctx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mctx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    387\u001b[0m   \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    388\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m execute\u001b[38;5;241m.\u001b[39mexecute_with_cancellation(\n\u001b[1;32m    389\u001b[0m         
\u001b[38;5;28mstr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msignature\u001b[38;5;241m.\u001b[39mname),\n\u001b[1;32m    390\u001b[0m         num_outputs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_outputs,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    393\u001b[0m         ctx\u001b[38;5;241m=\u001b[39mctx,\n\u001b[1;32m    394\u001b[0m         cancellation_manager\u001b[38;5;241m=\u001b[39mcancellation_manager)\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/eager/execute.py:52\u001b[0m, in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m     50\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m     51\u001b[0m   ctx\u001b[38;5;241m.\u001b[39mensure_initialized()\n\u001b[0;32m---> 52\u001b[0m   tensors \u001b[38;5;241m=\u001b[39m \u001b[43mpywrap_tfe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTFE_Py_Execute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mctx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_handle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mop_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     53\u001b[0m \u001b[43m                                      \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_outputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     54\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m core\u001b[38;5;241m.\u001b[39m_NotOkStatusException \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m     55\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "best_model.fit(\n",
    "    train_dataset,\n",
    "    epochs=1000,\n",
    "    verbose=1,\n",
    "    callbacks=[tensorboard_callback, lr_decay_callback, checkpoint_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f37e02fa-ee24-4662-b98e-a1189593b363",
   "metadata": {},
   "source": [
    "## Load the model if it has already been trained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "90c2cff6-9862-4a36-9624-a09ca0c36627",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Error in loading the saved optimizer state. As a result, your model is starting with a freshly initialized optimizer.\n"
     ]
    }
   ],
   "source": [
    "# best_model.save(\"202309_english_place_names.keras\")\n",
    "best_model = load_model(\"20230909_english_place_names.keras\")  # best one it seems\n",
    "# best_model = load_model(\"LSTM_places_64.keras\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "eb0197df-b444-4b39-b86b-0b9c8641bc3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " lstm (LSTM)                 (None, 114)               69768     \n",
      "                                                                 \n",
      " dense (Dense)               (None, 38)                4370      \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 74,138\n",
      "Trainable params: 74,138\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "best_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f40dc1ee-36f9-43fc-8c4f-34ad59561d03",
   "metadata": {},
   "outputs": [],
   "source": [
    "start = \"abcdefghijklmnopqrstuvwxyz\"\n",
    "letters = [c for c in start]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc0e95fd-b6d7-410d-8668-f42bc4a3aeca",
   "metadata": {},
   "source": [
    "## How unique are the generated names?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "9d83a71d-26d4-4d2e-b96b-245259cd6269",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                    | 0/26 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|██▉                                                                         | 1/26 [00:03<01:32,  3.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|█████▊                                                                      | 2/26 [00:08<01:42,  4.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████████▊                                                                   | 3/26 [00:12<01:35,  4.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|███████████▋                                                                | 4/26 [00:17<01:38,  4.46s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|██████████████▌                                                             | 5/26 [00:21<01:31,  4.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|█████████████████▌                                                          | 6/26 [00:24<01:17,  3.85s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 27%|████████████████████▍                                                       | 7/26 [00:27<01:11,  3.77s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 31%|███████████████████████▍                                                    | 8/26 [00:31<01:06,  3.71s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 35%|██████████████████████████▎                                                 | 9/26 [00:34<00:59,  3.53s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|████████████████████████████▊                                              | 10/26 [00:37<00:52,  3.28s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|███████████████████████████████▋                                           | 11/26 [00:41<00:53,  3.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|██████████████████████████████████▌                                        | 12/26 [00:46<00:54,  3.86s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████████████████████████████████████▌                                     | 13/26 [00:49<00:49,  3.81s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|████████████████████████████████████████▍                                  | 14/26 [00:55<00:52,  4.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|███████████████████████████████████████████▎                               | 15/26 [00:59<00:45,  4.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████████████████████████████████████████████▏                            | 16/26 [01:03<00:40,  4.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 65%|█████████████████████████████████████████████████                          | 17/26 [01:06<00:34,  3.84s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 69%|███████████████████████████████████████████████████▉                       | 18/26 [01:09<00:29,  3.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 73%|██████████████████████████████████████████████████████▊                    | 19/26 [01:13<00:27,  3.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 77%|█████████████████████████████████████████████████████████▋                 | 20/26 [01:17<00:22,  3.71s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 81%|████████████████████████████████████████████████████████████▌              | 21/26 [01:20<00:18,  3.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 85%|███████████████████████████████████████████████████████████████▍           | 22/26 [01:23<00:13,  3.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 88%|██████████████████████████████████████████████████████████████████▎        | 23/26 [01:27<00:10,  3.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████████████████████████████████████████████████████████████████▏     | 24/26 [01:29<00:06,  3.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|████████████████████████████████████████████████████████████████████████   | 25/26 [01:33<00:03,  3.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████| 26/26 [01:36<00:00,  3.69s/it]\n"
     ]
    }
   ],
   "source": [
    "# do 5 generations per letter (the count passed to generate_words below)\n",
    "results = []\n",
    "for l in tqdm(letters):\n",
    "    res = generate_words(\n",
    "        best_model,\n",
    "        vocab_size,\n",
    "        max_len,\n",
    "        idx_to_char,\n",
    "        char_to_idx,\n",
    "        5,\n",
    "        temperature=1,\n",
    "        seed_word=l,\n",
    "    )\n",
    "    results += res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "9ea3e7af-5f9e-46c6-baaa-4c913021b871",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[':t rasheston',\n",
       " '& ountewston',\n",
       " 'jarvis end ed',\n",
       " 'gerborough with redland',\n",
       " 'north raydon',\n",
       " 'zeals and haresford ed',\n",
       " ')',\n",
       " 'low harwood street',\n",
       " ':t puconhill dale',\n",
       " 'charlestock']"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generate_words(best_model, vocab_size, max_len, idx_to_char, char_to_idx, 10, temperature=1, seed_word=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "95afda08-3075-46ab-bf81-0bb5815d425f",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"place_name\"] = df[\"place_name\"].str.lower()\n",
    "df_check = df.set_index(keys=\"place_name\", inplace=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "832cae07-6556-465d-b094-47038d3e1cea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "starting with the letter 'a' there are 1661 words in the training set\n"
     ]
    }
   ],
   "source": [
    "# How many unique words can I make? Baseline: count training-set names that start with this letter.\n",
    "starting_letter = \"a\"\n",
    "num_in_training_data = df_check.filter(regex=f\"^{starting_letter}.*\", axis=0).shape[0]\n",
    "print(f\"starting with the letter '{starting_letter}' there are {num_in_training_data} words in the training set\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01316982-c09f-4b48-8d23-e804afb82ca7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "starting with 1660 words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                    | 0/200 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|▎                                                                         | 1/200 [01:01<3:23:39, 61.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▋                                                                         | 2/200 [02:07<3:31:22, 64.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100 89\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|█                                                                         | 3/200 [03:06<3:23:02, 61.84s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "200 175\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|█▍                                                                        | 4/200 [04:07<3:21:16, 61.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "300 258\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|█▊                                                                        | 5/200 [05:06<3:16:48, 60.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "400 336\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  3%|██▏                                                                       | 6/200 [06:09<3:18:30, 61.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "500 422\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|██▌                                                                       | 7/200 [07:07<3:13:35, 60.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "600 511\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|██▉                                                                       | 8/200 [08:07<3:13:12, 60.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "700 594\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|███▎                                                                      | 9/200 [09:09<3:13:48, 60.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "800 677\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|███▋                                                                     | 10/200 [10:12<3:14:20, 61.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "900 758\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|████                                                                     | 11/200 [11:14<3:13:33, 61.45s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000 833\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|████▍                                                                    | 12/200 [12:12<3:09:59, 60.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1100 913\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|████▋                                                                    | 13/200 [13:08<3:04:38, 59.24s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1200 993\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|█████                                                                    | 14/200 [14:03<2:59:00, 57.75s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1300 1067\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|█████▍                                                                   | 15/200 [15:02<2:59:52, 58.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1400 1140\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|█████▊                                                                   | 16/200 [16:07<3:05:02, 60.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1500 1219\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|██████▏                                                                  | 17/200 [17:07<3:03:44, 60.24s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1600 1301\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|██████▌                                                                  | 18/200 [18:08<3:03:29, 60.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1700 1374\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|██████▉                                                                  | 19/200 [19:06<2:59:39, 59.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1800 1458\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███████▎                                                                 | 20/200 [20:04<2:57:32, 59.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1900 1531\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|███████▋                                                                 | 21/200 [21:00<2:53:32, 58.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2000 1616\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|████████                                                                 | 22/200 [22:00<2:54:39, 58.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2100 1693\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████████▍                                                                | 23/200 [23:01<2:55:35, 59.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2200 1770\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|████████▊                                                                | 24/200 [24:01<2:54:52, 59.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2300 1846\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 12%|█████████▏                                                               | 25/200 [25:02<2:54:23, 59.79s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2400 1923\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|█████████▍                                                               | 26/200 [26:01<2:52:46, 59.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2500 1991\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|█████████▊                                                               | 27/200 [27:00<2:51:28, 59.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2600 2069\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|██████████▏                                                              | 28/200 [28:00<2:51:31, 59.83s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2700 2145\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 14%|██████████▌                                                              | 29/200 [29:00<2:50:02, 59.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2800 2221\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 15%|██████████▉                                                              | 30/200 [29:59<2:49:03, 59.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2900 2283\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|███████████▎                                                             | 31/200 [30:55<2:44:24, 58.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3000 2355\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|███████████▋                                                             | 32/200 [31:55<2:44:59, 58.93s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3100 2420\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|████████████                                                             | 33/200 [32:57<2:46:42, 59.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3200 2496\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|████████████▍                                                            | 34/200 [33:58<2:46:05, 60.04s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3300 2567\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|████████████▊                                                            | 35/200 [34:59<2:45:53, 60.33s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3400 2643\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█████████████▏                                                           | 36/200 [35:56<2:42:24, 59.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3500 2716\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█████████████▌                                                           | 37/200 [36:59<2:44:50, 60.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3600 2795\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 19%|█████████████▊                                                           | 38/200 [37:55<2:39:55, 59.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3700 2872\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██████████████▏                                                          | 39/200 [38:51<2:36:00, 58.14s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3800 2942\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██████████████▌                                                          | 40/200 [39:50<2:35:27, 58.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3900 3007\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██████████████▉                                                          | 41/200 [40:49<2:35:29, 58.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4000 3086\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 21%|███████████████▎                                                         | 42/200 [41:46<2:33:14, 58.19s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4100 3167\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|███████████████▋                                                         | 43/200 [42:48<2:34:51, 59.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4200 3232\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|████████████████                                                         | 44/200 [43:48<2:34:31, 59.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4300 3302\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|████████████████▍                                                        | 45/200 [44:45<2:31:30, 58.65s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4400 3374\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|████████████████▊                                                        | 46/200 [45:43<2:30:42, 58.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4500 3452\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|█████████████████▏                                                       | 47/200 [46:41<2:29:11, 58.51s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4600 3526\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|█████████████████▌                                                       | 48/200 [47:42<2:29:30, 59.02s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4700 3597\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|█████████████████▉                                                       | 49/200 [48:43<2:30:10, 59.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4800 3673\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 25%|██████████████████▎                                                      | 50/200 [49:42<2:28:28, 59.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4900 3750\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██████████████████▌                                                      | 51/200 [50:42<2:28:21, 59.74s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5000 3826\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|██████████████████▉                                                      | 52/200 [51:39<2:25:27, 58.97s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5100 3901\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 26%|███████████████████▎                                                     | 53/200 [52:37<2:23:32, 58.59s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5200 3971\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 27%|███████████████████▋                                                     | 54/200 [53:35<2:21:54, 58.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5300 4037\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|████████████████████                                                     | 55/200 [54:34<2:21:31, 58.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5400 4103\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|████████████████████▍                                                    | 56/200 [55:38<2:24:45, 60.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5500 4172\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 28%|████████████████████▊                                                    | 57/200 [56:35<2:21:19, 59.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5600 4247\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 29%|█████████████████████▏                                                   | 58/200 [57:36<2:21:41, 59.87s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5700 4318\n",
      "generating words\n"
     ]
    }
   ],
   "source": [
    "num_generated = [0]  # cumulative words generated after each batch\n",
    "num_unique = [0]  # cumulative never-seen-before words after each batch\n",
    "runs = 200\n",
    "num_per_batch=100\n",
    "set_of_original_words = set(df_check.filter(regex=f\"^{starting_letter}.*\", axis=0).index)\n",
    "orig_word_count = len(set_of_original_words)\n",
    "print(f\"starting with {orig_word_count} words\")\n",
    "for i in tqdm(range(runs)):\n",
    "    words = generate_words(best_model, vocab_size, max_len, idx_to_char, char_to_idx, num_per_batch, temperature=1, seed_word=starting_letter)\n",
    "    num_generated.append(num_generated[-1] + num_per_batch)\n",
    "    # words in this batch that are not in the dataset and were not produced by an earlier batch\n",
    "    unique_words = set(words).difference(set_of_original_words)\n",
    "    number_unique = len(unique_words)\n",
    "    num_unique.append(num_unique[-1] + number_unique)\n",
    "    # grow the reference set so later batches are compared against everything seen so far\n",
    "    set_of_original_words = set_of_original_words.union(set(words))\n",
    "    # report the running totals for the batch just processed (index i would be one batch behind)\n",
    "    print(num_generated[-1], num_unique[-1])\n",
    "\n",
    "new_word_count = len(set_of_original_words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "05979647-6ac7-4682-9c61-1bbc57ac7f24",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# previously saved results; the plotting cell reads its \"Generated\" and \"Unique\" columns\n",
    "df4 = pd.read_csv(filepath_or_buffer=\"output.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "15904b19",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.705469387755102"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ratio of two hard-coded counts (presumably unique / generated from an earlier run -- TODO confirm)\n",
    "8642 / 12250"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4477ce68",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Axes: xlabel='Generated'>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot the Unique column against the Generated column; trailing ';' hides the Axes repr\n",
    "df4.plot(\n",
    "    x=\"Generated\",\n",
    "    y=\"Unique\",\n",
    "    ylabel=\"Unique\",\n",
    "    title=\"Unique words vs. words generated\",\n",
    "    legend=False,\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d2c2620-0287-4e16-ba66-6f17a8434506",
   "metadata": {},
   "outputs": [],
   "source": [
    "# vocabulary size before vs. after the generation loop\n",
    "word_count_msg = f\"original words were {orig_word_count} now we have {new_word_count}\"\n",
    "print(word_count_msg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bdca7db-4711-488d-add7-31c747366e3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# upper bound on new words is everything generated; actual gain is the growth of the word set\n",
    "max_possible_new = max(num_generated)\n",
    "actual_new = new_word_count - orig_word_count\n",
    "print(f\"most we can get is {max_possible_new} extra but we got {actual_new}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e53f7d-fbcf-44c2-a3ab-124476b59bdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# persist the growth curve: cumulative words generated vs. cumulative new words\n",
    "dfs = pd.DataFrame(\n",
    "    {\n",
    "        \"num_gen\": num_generated,\n",
    "        \"num_unique\": num_unique,\n",
    "    }\n",
    ")\n",
    "dfs.to_excel(\"save_a.xlsx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75a4813b-b8e4-4046-98ae-d3fe81b4f70e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cumulative new words against cumulative words generated; ';' suppresses the return repr\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(num_generated, num_unique)\n",
    "ax.set(\n",
    "    xlabel=\"words generated\",\n",
    "    ylabel=\"new unique words\",\n",
    "    title=\"New unique words vs. words generated\",\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cc1f968-87e2-4b64-ad71-6040ed517870",
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep only generated words that are not already in the dataset index\n",
    "# NOTE(review): expects `results` to be a flat list of words; later cells rebind `results`\n",
    "filtered_words = [w for w in results if w not in df_check.index]\n",
    "\n",
    "print(\n",
    "    f\"generated {len(results)} new words and out of these {len(filtered_words)} are unique\"\n",
    ")\n",
    "print(len(set(results)), len(set(filtered_words)))\n",
    "print(filtered_words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "66f010bd-0924-4431-8ba7-0c704d765997",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain Python function, deliberately NOT a @tf.function: the body relies on NumPy\n",
    "# sampling, Python loops, tqdm and string concatenation, none of which can be traced\n",
    "# into a graph, and Keras raises RuntimeError for Model.predict inside a tf.function.\n",
    "def generate_words_parallel(\n",
    "    model,\n",
    "    vocab_size,\n",
    "    max_len,\n",
    "    idx_to_char,\n",
    "    char_to_idx,\n",
    "    number=20,\n",
    "    temperature=1,\n",
    "    seed_word=None,\n",
    "):\n",
    "    \"\"\"Generate `number` words by sampling the model one character at a time.\n",
    "\n",
    "    For every word the model is run on the current prefix and the next character is\n",
    "    sampled from the softmax output (not the argmax), so each call yields different\n",
    "    words. A word ends when a newline is sampled or when it reaches `max_len`\n",
    "    characters. Despite the name, words are generated sequentially.\n",
    "\n",
    "    Args:\n",
    "        model: Keras model taking a one-hot array of shape (1, max_len, vocab_size)\n",
    "            and returning next-character probabilities.\n",
    "        vocab_size: number of characters in the vocabulary.\n",
    "        max_len: maximum word length (the model's input sequence length).\n",
    "        idx_to_char: maps a character index to its character.\n",
    "        char_to_idx: maps a character to its index.\n",
    "        number: how many words to generate.\n",
    "        temperature: sampling temperature; 1 keeps the model's distribution.\n",
    "        seed_word: prefix for every word; if None a random character with index >= 2\n",
    "            is drawn per word (presumably skipping reserved tokens -- TODO confirm).\n",
    "\n",
    "    Returns:\n",
    "        list of generated words (str).\n",
    "    \"\"\"\n",
    "\n",
    "    def adjust_temperature(predictions, temperature):\n",
    "        # reweight the distribution; log(0) = -inf is intended (zero probability stays zero)\n",
    "        with np.errstate(divide=\"ignore\"):\n",
    "            logits = np.log(predictions) / temperature\n",
    "        exp_preds = np.exp(logits - np.max(logits))  # shift by the max for numerical stability\n",
    "        return exp_preds / np.sum(exp_preds)\n",
    "\n",
    "    def next_char(preds):\n",
    "        # float32 softmax output may not sum to exactly 1 in float64, which makes\n",
    "        # np.random.choice raise, so renormalise before sampling\n",
    "        probs = np.asarray(preds, dtype=np.float64).ravel()\n",
    "        probs = probs / probs.sum()\n",
    "        next_idx = np.random.choice(vocab_size, p=probs)\n",
    "        return idx_to_char[next_idx]\n",
    "\n",
    "    def word_to_input(word: str):\n",
    "        \"\"\"One-hot encode `word` into a (1, max_len, vocab_size) array.\"\"\"\n",
    "        x_pred = np.zeros((1, max_len, vocab_size))\n",
    "        for t, char in enumerate(word):\n",
    "            x_pred[0, t, char_to_idx[char]] = 1.0\n",
    "        return x_pred\n",
    "\n",
    "    def generate_word(prefix):\n",
    "        # extend the prefix one sampled character at a time until a newline is\n",
    "        # sampled or the word fills the model's max_len input\n",
    "        while len(prefix) < max_len:\n",
    "            # call the model directly: Model.predict is disallowed inside tf.function\n",
    "            # and carries heavy per-call overhead for single-sample inputs\n",
    "            preds = model(word_to_input(prefix), training=False).numpy()\n",
    "            if temperature != 1:\n",
    "                preds = adjust_temperature(preds, temperature)\n",
    "            char = next_char(preds)\n",
    "            if char == \"\\n\":\n",
    "                break\n",
    "            prefix += char\n",
    "        return prefix\n",
    "\n",
    "    output = []\n",
    "    print(\"generating words\")\n",
    "    for _ in tqdm(range(number)):\n",
    "        start = seed_word\n",
    "        if start is None:\n",
    "            start = idx_to_char[np.random.choice(np.arange(2, len(char_to_idx)))]\n",
    "        output.append(generate_word(start))\n",
    "    return output\n",
    "\n",
    "\n",
    "starting_letters = [\"a\", \"b\", \"c\", \"d\", \"e\"]  # List of starting letters\n",
    "\n",
    "# Create a TensorFlow Dataset from the starting letters\n",
    "# NOTE(review): `dataset` does not appear to be used by the later cells\n",
    "dataset = tf.data.Dataset.from_tensor_slices(starting_letters)\n",
    "\n",
    "\n",
    "def run_generate_word(letter):\n",
    "    \"\"\"Generate 10 words that all start with `letter`.\n",
    "\n",
    "    NOTE(review): uses the global `model`, while other generation cells use\n",
    "    `best_model` -- confirm which model is intended.\n",
    "    \"\"\"\n",
    "    return generate_words_parallel(\n",
    "        model,\n",
    "        vocab_size,\n",
    "        max_len,\n",
    "        idx_to_char,\n",
    "        char_to_idx,\n",
    "        number=10,\n",
    "        temperature=1,\n",
    "        seed_word=letter,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "d7b91fa0-4d47-468c-87dc-aa10a1a84f2e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: shape=(), dtype=string, numpy=b'a'>,\n",
       " <tf.Tensor: shape=(), dtype=string, numpy=b'b'>,\n",
       " <tf.Tensor: shape=(), dtype=string, numpy=b'c'>,\n",
       " <tf.Tensor: shape=(), dtype=string, numpy=b'd'>,\n",
       " <tf.Tensor: shape=(), dtype=string, numpy=b'e'>]"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# build the tensor-wrapped seed letters here so this cell does not depend on the cell after it\n",
    "starting_letters_list = [\n",
    "    tf.constant(letter, dtype=tf.string) for letter in starting_letters\n",
    "]\n",
    "starting_letters_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "9f4aecee-ae5a-495d-be23-d549ff3a07e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a\n",
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                     | 0/10 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "in user code:\n\n    File \"/tmp/ipykernel_519540/1283527189.py\", line 32, in generate_word  *\n        preds = model.predict(x_input, verbose=False)\n    File \"/home/andy/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py\", line 70, in error_handler  **\n        raise e.with_traceback(filtered_tb) from None\n    File \"/home/andy/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py\", line 4045, in _disallow_inside_tf_function\n        raise RuntimeError(error_msg)\n\n    RuntimeError: Detected a call to `Model.predict` inside a `tf.function`. `Model.predict is a high-level endpoint that manages its own `tf.function`. Please move the call to `Model.predict` outside of all enclosing `tf.function`s. Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[46], line 10\u001b[0m\n\u001b[1;32m      4\u001b[0m starting_letters_list \u001b[38;5;241m=\u001b[39m [tf\u001b[38;5;241m.\u001b[39mconstant(letter, dtype\u001b[38;5;241m=\u001b[39mtf\u001b[38;5;241m.\u001b[39mstring) \u001b[38;5;28;01mfor\u001b[39;00m letter \u001b[38;5;129;01min\u001b[39;00m starting_letters]\n\u001b[1;32m      6\u001b[0m     \u001b[38;5;66;03m# Define a function to run generate_word() on the GPU with seed_word\u001b[39;00m\n\u001b[1;32m      7\u001b[0m     \n\u001b[1;32m      8\u001b[0m \n\u001b[1;32m      9\u001b[0m     \u001b[38;5;66;03m# Use the map function to run run_generate_word() in parallel\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m results \u001b[38;5;241m=\u001b[39m [run_generate_word(letter) \u001b[38;5;28;01mfor\u001b[39;00m letter \u001b[38;5;129;01min\u001b[39;00m starting_letters]\n",
      "Cell \u001b[0;32mIn[46], line 10\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m      4\u001b[0m starting_letters_list \u001b[38;5;241m=\u001b[39m [tf\u001b[38;5;241m.\u001b[39mconstant(letter, dtype\u001b[38;5;241m=\u001b[39mtf\u001b[38;5;241m.\u001b[39mstring) \u001b[38;5;28;01mfor\u001b[39;00m letter \u001b[38;5;129;01min\u001b[39;00m starting_letters]\n\u001b[1;32m      6\u001b[0m     \u001b[38;5;66;03m# Define a function to run generate_word() on the GPU with seed_word\u001b[39;00m\n\u001b[1;32m      7\u001b[0m     \n\u001b[1;32m      8\u001b[0m \n\u001b[1;32m      9\u001b[0m     \u001b[38;5;66;03m# Use the map function to run run_generate_word() in parallel\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m results \u001b[38;5;241m=\u001b[39m [\u001b[43mrun_generate_word\u001b[49m\u001b[43m(\u001b[49m\u001b[43mletter\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m letter \u001b[38;5;129;01min\u001b[39;00m starting_letters]\n",
      "Cell \u001b[0;32mIn[44], line 61\u001b[0m, in \u001b[0;36mrun_generate_word\u001b[0;34m(letter)\u001b[0m\n\u001b[1;32m     60\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun_generate_word\u001b[39m(letter):\n\u001b[0;32m---> 61\u001b[0m     temp_results \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_words_parallel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvocab_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_len\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx_to_char\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mchar_to_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnumber\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseed_word\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mletter\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     62\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m temp_results\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:153\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    151\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    152\u001b[0m   filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[0;32m--> 153\u001b[0m   \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    154\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    155\u001b[0m   \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
      "File \u001b[0;32m/tmp/__autograph_generated_file7oxbn_mg.py:185\u001b[0m, in \u001b[0;36mouter_factory.<locals>.inner_factory.<locals>.tf__generate_words_parallel\u001b[0;34m(model, vocab_size, max_len, idx_to_char, char_to_idx, number, temperature, seed_word)\u001b[0m\n\u001b[1;32m    183\u001b[0m word \u001b[38;5;241m=\u001b[39m ag__\u001b[38;5;241m.\u001b[39mUndefined(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mword\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    184\u001b[0m i \u001b[38;5;241m=\u001b[39m ag__\u001b[38;5;241m.\u001b[39mUndefined(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mi\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m--> 185\u001b[0m ag__\u001b[38;5;241m.\u001b[39mfor_stmt(ag__\u001b[38;5;241m.\u001b[39mconverted_call(ag__\u001b[38;5;241m.\u001b[39mld(tqdm), (ag__\u001b[38;5;241m.\u001b[39mconverted_call(ag__\u001b[38;5;241m.\u001b[39mld(\u001b[38;5;28mrange\u001b[39m), (ag__\u001b[38;5;241m.\u001b[39mld(number),), \u001b[38;5;28;01mNone\u001b[39;00m, fscope),), \u001b[38;5;28;01mNone\u001b[39;00m, fscope), \u001b[38;5;28;01mNone\u001b[39;00m, loop_body_1, get_state_5, set_state_5, (\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mseed_word\u001b[39m\u001b[38;5;124m'\u001b[39m,), {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124miterate_names\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mi\u001b[39m\u001b[38;5;124m'\u001b[39m})\n\u001b[1;32m    186\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    187\u001b[0m     do_return \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
      "File \u001b[0;32m/tmp/__autograph_generated_file7oxbn_mg.py:180\u001b[0m, in \u001b[0;36mouter_factory.<locals>.inner_factory.<locals>.tf__generate_words_parallel.<locals>.loop_body_1\u001b[0;34m(itr_1)\u001b[0m\n\u001b[1;32m    178\u001b[0m     \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m    179\u001b[0m ag__\u001b[38;5;241m.\u001b[39mif_stmt(ag__\u001b[38;5;241m.\u001b[39mld(seed_word) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, if_body_3, else_body_3, get_state_4, set_state_4, (\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mseed_word\u001b[39m\u001b[38;5;124m'\u001b[39m,), \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 180\u001b[0m word \u001b[38;5;241m=\u001b[39m \u001b[43mag__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconverted_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43mag__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mld\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgenerate_word\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mag__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mld\u001b[49m\u001b[43m(\u001b[49m\u001b[43mseed_word\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfscope\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    181\u001b[0m ag__\u001b[38;5;241m.\u001b[39mconverted_call(ag__\u001b[38;5;241m.\u001b[39mld(output)\u001b[38;5;241m.\u001b[39mappend, (ag__\u001b[38;5;241m.\u001b[39mld(word),), \u001b[38;5;28;01mNone\u001b[39;00m, fscope)\n\u001b[1;32m    182\u001b[0m seed_word \u001b[38;5;241m=\u001b[39m ag__\u001b[38;5;241m.\u001b[39mld(seed_word_original)\n",
      "File \u001b[0;32m/tmp/__autograph_generated_file7oxbn_mg.py:149\u001b[0m, in \u001b[0;36mouter_factory.<locals>.inner_factory.<locals>.tf__generate_words_parallel.<locals>.generate_word\u001b[0;34m(seed_word, i)\u001b[0m\n\u001b[1;32m    147\u001b[0m preds \u001b[38;5;241m=\u001b[39m ag__\u001b[38;5;241m.\u001b[39mUndefined(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpreds\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    148\u001b[0m char \u001b[38;5;241m=\u001b[39m ag__\u001b[38;5;241m.\u001b[39mUndefined(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mchar\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m--> 149\u001b[0m ag__\u001b[38;5;241m.\u001b[39mif_stmt(ag__\u001b[38;5;241m.\u001b[39mconverted_call(ag__\u001b[38;5;241m.\u001b[39mld(\u001b[38;5;28mlen\u001b[39m), (ag__\u001b[38;5;241m.\u001b[39mld(seed_word),), \u001b[38;5;28;01mNone\u001b[39;00m, fscope_1) \u001b[38;5;241m==\u001b[39m ag__\u001b[38;5;241m.\u001b[39mld(max_len), if_body_2, else_body_2, get_state_3, set_state_3, (\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdo_return_1\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mretval__1\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mi\u001b[39m\u001b[38;5;124m'\u001b[39m), \u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m    150\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fscope_1\u001b[38;5;241m.\u001b[39mret(retval__1, do_return_1)\n",
      "File \u001b[0;32m/tmp/__autograph_generated_file7oxbn_mg.py:100\u001b[0m, in \u001b[0;36mouter_factory.<locals>.inner_factory.<locals>.tf__generate_words_parallel.<locals>.generate_word.<locals>.else_body_2\u001b[0;34m()\u001b[0m\n\u001b[1;32m     98\u001b[0m \u001b[38;5;28;01mnonlocal\u001b[39;00m retval__1, do_return_1, i\n\u001b[1;32m     99\u001b[0m x_input \u001b[38;5;241m=\u001b[39m ag__\u001b[38;5;241m.\u001b[39mconverted_call(ag__\u001b[38;5;241m.\u001b[39mld(word_to_input), (ag__\u001b[38;5;241m.\u001b[39mld(seed_word),), \u001b[38;5;28;01mNone\u001b[39;00m, fscope_1)\n\u001b[0;32m--> 100\u001b[0m preds \u001b[38;5;241m=\u001b[39m \u001b[43mag__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconverted_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43mag__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mld\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mag__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mld\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_input\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mdict\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfscope_1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    102\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_state_1\u001b[39m():\n\u001b[1;32m    103\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m (preds,)\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py:70\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     67\u001b[0m     filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[1;32m     68\u001b[0m     \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[1;32m     69\u001b[0m     \u001b[38;5;66;03m# `tf.debugging.disable_traceback_filtering()`\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m     71\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m     72\u001b[0m     \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
      "File \u001b[0;32m~/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py:4045\u001b[0m, in \u001b[0;36m_disallow_inside_tf_function\u001b[0;34m(method_name)\u001b[0m\n\u001b[1;32m   4036\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tf\u001b[38;5;241m.\u001b[39minside_function():\n\u001b[1;32m   4037\u001b[0m     error_msg \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m   4038\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDetected a call to `Model.\u001b[39m\u001b[38;5;132;01m{method_name}\u001b[39;00m\u001b[38;5;124m` inside a `tf.function`. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   4039\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Model.\u001b[39m\u001b[38;5;132;01m{method_name}\u001b[39;00m\u001b[38;5;124m is a high-level endpoint that manages its \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   4043\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`model(x)`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   4044\u001b[0m     )\u001b[38;5;241m.\u001b[39mformat(method_name\u001b[38;5;241m=\u001b[39mmethod_name)\n\u001b[0;32m-> 4045\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(error_msg)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: in user code:\n\n    File \"/tmp/ipykernel_519540/1283527189.py\", line 32, in generate_word  *\n        preds = model.predict(x_input, verbose=False)\n    File \"/home/andy/miniconda3/envs/tf/lib/python3.9/site-packages/keras/utils/traceback_utils.py\", line 70, in error_handler  **\n        raise e.with_traceback(filtered_tb) from None\n    File \"/home/andy/miniconda3/envs/tf/lib/python3.9/site-packages/keras/engine/training.py\", line 4045, in _disallow_inside_tf_function\n        raise RuntimeError(error_msg)\n\n    RuntimeError: Detected a call to `Model.predict` inside a `tf.function`. `Model.predict is a high-level endpoint that manages its own `tf.function`. Please move the call to `Model.predict` outside of all enclosing `tf.function`s. Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.\n"
     ]
    }
   ],
   "source": [
    "starting_letters = [\"a\", \"b\", \"c\", \"d\", \"e\"]  # List of starting letters\n",
    "\n",
    "# Wrap each letter as a scalar tf.string tensor\n",
    "starting_letters_list = [tf.constant(ch, dtype=tf.string) for ch in starting_letters]\n",
    "\n",
    "# Generate words for each starting letter, one letter after another (not in parallel)\n",
    "results = list(map(run_generate_word, starting_letters))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2fbeada-2827-41dc-93b3-d8dde989132c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process the results as needed\n",
    "# (the loop variable is deliberately not named `starting_letter`: that would overwrite\n",
    "#  the global seed letter used by the generation-loop cell)\n",
    "for letter, generated_words in zip(starting_letters, results):\n",
    "    print(f\"Starting Letter: {letter}, Generated Word: {generated_words}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "7831a7ea-9779-4bb8-b178-550fe9234503",
   "metadata": {},
   "outputs": [],
   "source": [
    "# one-column frame of the generated words that were not found in the dataset\n",
    "new_places_column = {\"new_places\": filtered_words}\n",
    "df_words = pd.DataFrame(data=new_places_column)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "fb30cb62-b42b-48cd-9c32-378ba733c3e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# write the new place names to disk (the row index is written too, pandas' default)\n",
    "df_words.to_csv(path_or_buf=\"new_places.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f039d33-4d39-4861-a62d-ce0277f6ebce",
   "metadata": {},
   "source": [
    "# For codewords only\n",
    "We can use the `refine_words` function to make each generated word pronounceable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "87a8b3f9-dfcc-4198-a951-748b6778ce72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generating words\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                     | 0/50 [00:00<?, ?it/s]/tmp/ipykernel_205769/517760797.py:188: RuntimeWarning: divide by zero encountered in log\n",
      "  predictions = np.log(predictions) / temperature\n",
      "100%|████████████████████████████████████████████████████████████████████████████| 50/50 [00:23<00:00,  2.12it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████| 34/34 [00:18<00:00,  1.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('lozurphy', 'gloze furphy')\n",
      "('apnamo guek', 'sapin morgue')\n",
      "('zoulttihg', 'zygous slitting')\n",
      "('zorrape', 'or raphe')\n",
      "('oceansurs', 'ocean sur')\n",
      "('mrosnefscarter', 'miro newscaster')\n",
      "('pogondewa', 'prog nonda')\n",
      "('inchtigue', 'inch fatigue')\n",
      "('krawaint', 'krawant')\n",
      "('kilming', 'kil mingy')\n",
      "('kraditreant', 'kra distraint')\n",
      "('farlemedeide', 'farl Diomedeidae')\n",
      "('nsa forn', 'ansafor')\n",
      "('prateespleude', 'prate splender')\n",
      "('yankeeoohe', 'yank eyetooth')\n",
      "('jourryide', 'jocuride')\n",
      "('glyrase', 'ugly rasse')\n",
      "('bluesker', 'blues kier')\n",
      "('dirrosse', 'dirl rosser')\n",
      "('cabbalant', 'cab bagplant')\n",
      "('wellere', 'wellwere')\n",
      "('xkeyscorpe', 'oxlike scorper')\n",
      "('eveningemathhion', 'evening emulation')\n",
      "('fallurg', 'fallburg')\n",
      "('py calro', 'pyx claro')\n",
      "('quadresgeat', 'quad resweat')\n",
      "('aorbat', 'asorboat')\n",
      "('unigorvie', 'unio corvine')\n",
      "('epicfard', 'epic fardo')\n",
      "('quigatus', 'quis latus')\n",
      "('queenzy', 'queen zany')\n",
      "('wescrodber', 'wels rober')\n",
      "('wilkatcher', 'wilk patcher')\n",
      "('ostiomon', 'oust timon')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "new_words = generate_words(\n",
    "    best_model,\n",
    "    vocab_size,\n",
    "    max_len,\n",
    "    idx_to_char,\n",
    "    char_to_idx,\n",
    "    10,\n",
    "    temperature=1,\n",
    "    seed_word=None,\n",
    ")\n",
    "# drop words already present in `old_words`\n",
    "new_words = [w for w in new_words if w not in old_words]\n",
    "new_words_refined = refine_words(new_words)\n",
    "# named `pairs` rather than `all`, which would shadow the built-in all() for the whole session\n",
    "pairs = zip(new_words, new_words_refined)\n",
    "for combo in pairs:\n",
    "    print(combo)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "224cd22e-9950-4c5d-9f59-cdeed3a4a247",
   "metadata": {},
   "source": [
    "## Clear the Keras session to free memory, then rebuild and train the LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4b86665-79a1-4bc2-9def-daec1ce0c8d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.keras.backend.clear_session()\n",
    "batch_size = 2048\n",
    "results_dict = {}  # NOTE(review): keyed by size only, so multiple learning rates overwrite each other\n",
    "loss_dict = {}\n",
    "for size in [64]:\n",
    "    lr = [1e-4]\n",
    "    for learning_rate in lr:\n",
    "        model = build_model_LSTM(\n",
    "            size, x_words, batch_size, max_len, vocab_size, learning_rate=learning_rate\n",
    "        )\n",
    "        loss_values = run_model(model, x_words, y_words, 50, size, type=\"LSTM_words\")\n",
    "        # same positional order as the other generate_words calls in this notebook:\n",
    "        # (model, vocab_size, max_len, idx_to_char, char_to_idx, number)\n",
    "        results = generate_words(model, vocab_size, max_len, idx_to_char, char_to_idx, 10)\n",
    "        results_dict[size] = results\n",
    "        print(\"model size \", size, \" results \", results)\n",
    "        loss_dict[size] = loss_values"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}