diff --git "a/fsrs4anki_optimizer.ipynb" "b/fsrs4anki_optimizer.ipynb" new file mode 100644 --- /dev/null +++ "b/fsrs4anki_optimizer.ipynb" @@ -0,0 +1,1180 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FSRS4Anki v3.10.1 Optimizer" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "lurCmW0Jqz3s" + }, + "source": [ + "[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v3.10.1/fsrs4anki_optimizer.ipynb)\n", + "\n", + "↑ Click the button above to open the optimizer in Google Colab.\n", + "\n", + "> If you can't see the button and are located in the Chinese Mainland, please use a proxy or VPN." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wG7bBfGJFbMr" + }, + "source": [ + "Upload your **Anki Deck Package (.apkg)** file or **Anki Collection Package (.colpkg)** file via the `Left sidebar -> Files` panel: drag and drop your file into the current directory (not the `sample_data` directory). \n", + "\n", + "You don't need to include media, but you do need to include scheduling information. \n", + "\n", + "> If you use the latest version of Anki, please check the box `Support older Anki versions (slower/larger files)` when you export.\n", + "\n", + "You can export it via `File -> Export...` or `Ctrl + E` in the main window of Anki.\n", + "\n", + "Then replace the `filename` in the next code cell with yours, and set the `timezone` and `next_day_starts_at` values, both of which can be found in Anki's preferences.\n", + "\n", + "After that, just run all (`Runtime -> Run all` or `Ctrl + F9`) and wait a few minutes. You can see the optimal parameters in section **3 Result**. Copy them, use them to replace the parameters in `fsrs4anki_scheduler.js`, and paste the code into the custom scheduling section of your deck options (requires Anki version >= 2.1.55).\n", + "\n", + "**NOTE**: The default output is generated from my review logs. If your output is identical to mine, the notebook probably hasn't actually run on your data yet.\n", + "\n", + "**Contribute to SRS Research**: If you want to share your data with me, please fill out this form: https://forms.gle/KaojsBbhMCytaA7h8" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "iqP70_-3EUhi" + }, + "outputs": [], + "source": [ + "# Here are some settings that you need to replace before running this optimizer.\n", + "\n", + "filename = \"collection-2022-09-18@13-21-58.colpkg\"\n", + "# If you upload a deck file, replace it with your deck filename. E.g., ALL__Learning.apkg\n", + "# If you upload a collection file, replace it with your .colpkg filename. E.g., collection-2022-09-18@13-21-58.colpkg\n", + "\n", + "# Replace it with your timezone.
I'm in China, so I use Asia/Shanghai.\n", + "# You can find your timezone here: https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568\n", + "timezone = 'Asia/Shanghai'\n", + "\n", + "# Replace it with your setting in Anki's Preferences -> Scheduling.\n", + "next_day_starts_at = 4\n", + "\n", + "# Replace it if you don't want the optimizer to use review logs from before a specific date.\n", + "revlog_start_date = \"2006-10-05\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bLFVNmG2qd06" + }, + "source": [ + "## 1 Build dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EkzFeKawqgbs" + }, + "source": [ + "### 1.1 Extract Anki collection & deck file" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KD2js_wEr_Bs", + "outputId": "42653d9e-316e-40bc-bd1d-f3a0e2b246c7" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracted successfully!\n" + ] + } + ], + "source": [ + "import zipfile\n", + "import sqlite3\n", + "import time\n", + "import tqdm.notebook  # imported explicitly: plain `import tqdm` does not load the notebook submodule\n", + "import pandas as pd\n", + "import numpy as np\n", + "import os\n", + "from datetime import timedelta, datetime\n", + "import matplotlib.pyplot as plt\n", + "import math\n", + "import sys\n", + "import torch\n", + "from torch import nn\n", + "from sklearn.utils import shuffle\n", + "# Extract the collection file or deck file to get the .anki21 database.\n", + "with zipfile.ZipFile(f'./{filename}', 'r') as zip_ref:\n", + " zip_ref.extractall('./')\n", + " print(\"Extracted successfully!\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dKpy4VfqGmaL" + }, + "source": [ + "### 1.2 Create time-series features & analysis\n", + "\n", + "The following code cell will extract the review logs from your Anki collection and preprocess them into a training set, which is saved in `revlog_history.tsv`.\n", + "\n", + "The time-series features are important in optimizing the model's parameters. For more detail, please see my paper: https://www.maimemo.com/paper/\n", + "\n", + "Then it will generate a concise analysis of your review logs. \n", + "\n", + "- The `r_history` is the history of ratings on each review. `3,3,3,1` means that you pressed `Good, Good, Good, Again`. It only contains the first rating for each card on a given review date, i.e., if you press `Again` during review and `Good` in the relearning step 10 minutes later, only `Again` is recorded.\n", + "- The `avg_interval` is the actual average interval after you rate your cards as in the `r_history`. It could be longer than the interval given by Anki's built-in scheduler because you reviewed some overdue cards.\n", + "- The `avg_retention` is the average retention after you rate your cards as in the `r_history`. `Again` counts as a failed recall; `Hard`, `Good`, and `Easy` count as successful recalls. Retention is the percentage of successful recalls.\n", + "- The `stability` is the estimated memory state variable, which is an approximate interval that leads to 90% retention.\n", + "- The `factor` is `stability / previous stability`.\n", + "- The `group_cnt` is the number of review logs that have the same `r_history`."
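To make the `stability` column concrete: everywhere in this notebook, memory follows the exponential forgetting curve R(t) = 0.9^(t / S), so the stability S is exactly the interval at which predicted retention drops to 90%, and it can be recovered from an observed (interval, retention) pair by inverting the curve. A minimal sketch of that relationship (the numbers are illustrative, not taken from any real collection):

```python
import numpy as np

# Forgetting curve used throughout the notebook: R(t) = 0.9 ** (t / S).
# At t == S the predicted retention is exactly 90%.
def retention(t: float, stability: float) -> float:
    return 0.9 ** (t / stability)

# Inverting the curve for one (interval, retention) observation:
# ln(R) = ln(0.9) * t / S  =>  S = ln(0.9) * t / ln(R)
def stability_from_observation(t: float, r: float) -> float:
    return np.log(0.9) * t / np.log(r)

print(retention(10, 10))                     # 0.9 by construction
print(stability_from_observation(20, 0.81))  # 10.0, since 0.81 == 0.9 ** 2
```

The `cal_stability` step in the following code cell fits this same inversion across many (interval, retention) groups at once.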
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "J2IIaY3PDaaG", + "outputId": "607916c9-da95-48dd-fdab-6bd83fbbbb40" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "revlog.csv saved.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fec0c445154d4182bbff35e17f98e0ef", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/30711 [00:00 0) &\n", + " (df['id'] >= time.mktime(datetime.strptime(revlog_start_date, \"%Y-%m-%d\").timetuple()) * 1000)].copy()\n", + "df['create_date'] = pd.to_datetime(df['cid'] // 1000, unit='s')\n", + "df['create_date'] = df['create_date'].dt.tz_localize(\n", + " 'UTC').dt.tz_convert(timezone)\n", + "df['review_date'] = pd.to_datetime(df['id'] // 1000, unit='s')\n", + "df['review_date'] = df['review_date'].dt.tz_localize(\n", + " 'UTC').dt.tz_convert(timezone)\n", + "df.drop(df[df['review_date'].dt.year < 2006].index, inplace=True)\n", + "df.sort_values(by=['cid', 'id'], inplace=True, ignore_index=True)\n", + "type_sequence = np.array(df['type'])\n", + "df.to_csv(\"revlog.csv\", index=False)\n", + "print(\"revlog.csv saved.\")\n", + "df = df[(df['type'] == 0) | (df['type'] == 1)].copy()\n", + "df['real_days'] = df['review_date'] - timedelta(hours=next_day_starts_at)\n", + "df['real_days'] = pd.DatetimeIndex(df['real_days'].dt.floor('D')).to_julian_date()\n", + "df.drop_duplicates(['cid', 'real_days'], keep='first', inplace=True)\n", + "df['delta_t'] = df.real_days.diff()\n", + "df.dropna(inplace=True)\n", + "df['delta_t'] = df['delta_t'].astype(dtype=int)\n", + "df['i'] = 1\n", + "df['r_history'] = \"\"\n", + "df['t_history'] = \"\"\n", + "col_idx = {key: i for i, key in enumerate(df.columns)}\n", + "\n", + "\n", + "# code from https://github.com/L-M-Sherlock/anki_revlog_analysis/blob/main/revlog_analysis.py\n", + "def get_feature(x):\n", + " for idx, log in enumerate(x.itertuples()):\n", + " if idx == 0:\n", + " x.iloc[idx, col_idx['delta_t']] = 0\n", + " if idx == x.shape[0] - 1:\n", + " break\n", + " x.iloc[idx + 1, col_idx['i']] = x.iloc[idx, col_idx['i']] + 1\n", + " x.iloc[idx + 1, col_idx['t_history']] = f\"{x.iloc[idx, col_idx['t_history']]},{x.iloc[idx, col_idx['delta_t']]}\"\n", + " x.iloc[idx + 1, col_idx['r_history']] = f\"{x.iloc[idx, col_idx['r_history']]},{x.iloc[idx, col_idx['r']]}\"\n", + " return x\n", + "\n", + "tqdm.notebook.tqdm.pandas()\n", + "df = df.groupby('cid', as_index=False).progress_apply(get_feature)\n", + "df[\"t_history\"] = df[\"t_history\"].map(lambda x: x[1:] if len(x) > 1 else x)\n", + "df[\"r_history\"] = df[\"r_history\"].map(lambda x: x[1:] if len(x) > 1 else x)\n", + "df.to_csv('revlog_history.tsv', sep=\"\\t\", index=False)\n", + "print(\"Trainset saved.\")\n", + "\n", + "def cal_retention(group: pd.DataFrame) -> pd.DataFrame:\n", + " group['retention'] = round(group['r'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x]).mean(), 4)\n", + " group['total_cnt'] = group.shape[0]\n", + " return group\n", + "\n", + "df = df.groupby(by=['r_history', 'delta_t']).progress_apply(cal_retention)\n", + "print(\"Retention calculated.\")\n", + "df = df.drop(columns=['id', 'cid', 'usn', 'ivl', 'last_lvl', 'factor', 'time', 'type', 'create_date', 'review_date', 'real_days', 'r', 't_history'])\n", + "df.drop_duplicates(inplace=True)\n", + "df = df[(df['retention'] < 1) & (df['retention'] > 0)]\n", + "\n", + "def cal_stability(group: pd.DataFrame) -> 
pd.DataFrame:\n", + " if group['i'].values[0] > 1:\n", + " r_ivl_cnt = sum(group['delta_t'] * group['retention'].map(np.log) * pow(group['total_cnt'], 2))\n", + " ivl_ivl_cnt = sum(group['delta_t'].map(lambda x: x ** 2) * pow(group['total_cnt'], 2))\n", + " group['stability'] = round(np.log(0.9) / (r_ivl_cnt / ivl_ivl_cnt), 1)\n", + " else:\n", + " group['stability'] = 0.0\n", + " group['group_cnt'] = sum(group['total_cnt'])\n", + " group['avg_retention'] = round(sum(group['retention'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 3)\n", + " group['avg_interval'] = round(sum(group['delta_t'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 1)\n", + " del group['total_cnt']\n", + " del group['retention']\n", + " del group['delta_t']\n", + " return group\n", + "\n", + "df = df.groupby(by=['r_history']).progress_apply(cal_stability)\n", + "print(\"Stability calculated.\")\n", + "df.reset_index(drop = True, inplace = True)\n", + "df.drop_duplicates(inplace=True)\n", + "df.sort_values(by=['r_history'], inplace=True, ignore_index=True)\n", + "\n", + "if df.shape[0] > 0:\n", + " for idx in tqdm.notebook.tqdm(df.index):\n", + " item = df.loc[idx]\n", + " index = df[(df['i'] == item['i'] + 1) & (df['r_history'].str.startswith(item['r_history']))].index\n", + " df.loc[index, 'last_stability'] = item['stability']\n", + " df['factor'] = round(df['stability'] / df['last_stability'], 2)\n", + " df = df[(df['i'] >= 2) & (df['group_cnt'] >= 100)]\n", + " df['last_recall'] = df['r_history'].map(lambda x: x[-1])\n", + " df = df[df.groupby(['i', 'r_history'])['group_cnt'].transform(max) == df['group_cnt']]\n", + " df.to_csv('./stability_for_analysis.tsv', sep='\\t', index=None)\n", + " print(\"1:again, 2:hard, 3:good, 4:easy\\n\")\n", + " print(df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(index=False))\n", + " print(\"Analysis saved!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k_SgzC-auWmu" + }, + "source": [ + "## 2 Optimize parameter" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WrfBJjqCHEwJ" + }, + "source": [ + "### 2.1 Define the model\n", + "\n", + "FSRS is a time-series model for predicting memory states." 
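One note before the model definition: the `cal_stability` function above is a closed-form weighted least-squares fit of ln R = (ln 0.9 / S) · t through the origin, with each row weighted by `total_cnt²`. A standalone sketch with made-up observations (not the notebook's code, just the same arithmetic):

```python
import numpy as np

# Toy (delta_t, retention, total_cnt) rows for one r_history group;
# the values are illustrative only.
delta_t = np.array([1.0, 2.0, 4.0])
ret = np.array([0.97, 0.93, 0.85])
cnt = np.array([300, 500, 200])

# Weighted least squares of ln(R) = (ln 0.9 / S) * t through the origin:
# slope = sum(w * t * ln R) / sum(w * t^2) with w = cnt ** 2,
# then S = ln(0.9) / slope -- identical to cal_stability's formula.
w2 = cnt ** 2
slope = np.sum(w2 * delta_t * np.log(ret)) / np.sum(w2 * delta_t ** 2)
print(round(np.log(0.9) / slope, 1))  # ~2.8: interval where fitted retention hits 90%
```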
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "tdYp3GMLhTYm" + }, + "outputs": [], + "source": [ + "init_w = [1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.02, 0.8, 2, -0.2, 0.5, 1]\n", + "\n", + "\n", + "class FSRS(nn.Module):\n", + " def __init__(self, w):\n", + " super(FSRS, self).__init__()\n", + " self.w = nn.Parameter(torch.FloatTensor(w))\n", + " self.zero = torch.FloatTensor([0.0])\n", + "\n", + " def forward(self, x, s, d):\n", + " '''\n", + " :param x: [review interval, review response]\n", + " :param s: stability\n", + " :param d: difficulty\n", + " :return:\n", + " '''\n", + " if torch.equal(s, self.zero):\n", + " # first learn, init memory states\n", + " new_s = self.w[0] + self.w[1] * (x[1] - 1)\n", + " new_d = self.w[2] + self.w[3] * (x[1] - 3)\n", + " new_d = new_d.clamp(1, 10)\n", + " else:\n", + " r = torch.exp(np.log(0.9) * x[0] / s)\n", + " new_d = d + self.w[4] * (x[1] - 3)\n", + " new_d = self.mean_reversion(self.w[2], new_d)\n", + " new_d = new_d.clamp(1, 10)\n", + " # recall\n", + " if x[1] > 1:\n", + " new_s = s * (1 + torch.exp(self.w[6]) *\n", + " (11 - new_d) *\n", + " torch.pow(s, self.w[7]) *\n", + " (torch.exp((1 - r) * self.w[8]) - 1))\n", + " # forget\n", + " else:\n", + " new_s = self.w[9] * torch.pow(new_d, self.w[10]) * torch.pow(\n", + " s, self.w[11]) * torch.exp((1 - r) * self.w[12])\n", + " return new_s, new_d\n", + "\n", + " def loss(self, s, t, r):\n", + " return - (r * np.log(0.9) * t / s + (1 - r) * torch.log(1 - torch.exp(np.log(0.9) * t / s)))\n", + "\n", + " def mean_reversion(self, init, current):\n", + " return self.w[5] * init + (1-self.w[5]) * current\n", + "\n", + "\n", + "class WeightClipper(object):\n", + " def __init__(self, frequency=1):\n", + " self.frequency = frequency\n", + "\n", + " def __call__(self, module):\n", + " if hasattr(module, 'w'):\n", + " w = module.w.data\n", + " w[0] = w[0].clamp(0.1, 10) # initStability\n", + " w[1] = w[1].clamp(0.1, 5) # initStabilityRatingFactor\n", + " w[2] = w[2].clamp(1, 10) # initDifficulty\n", + " w[3] = w[3].clamp(-5, -0.1) # initDifficultyRatingFactor\n", + " w[4] = w[4].clamp(-5, -0.1) # updateDifficultyRatingFactor\n", + " w[5] = w[5].clamp(0, 0.5) # difficultyMeanReversionFactor\n", + " w[6] = w[6].clamp(0, 2) # recallFactor\n", + " w[7] = w[7].clamp(-0.2, -0.01) # recallStabilityDecay\n", + " w[8] = w[8].clamp(0.01, 1.5) # recallRetrievabilityFactor\n", + " w[9] = w[9].clamp(0.5, 5) # forgetFactor\n", + " w[10] = w[10].clamp(-2, -0.01) # forgetDifficultyDecay\n", + " w[11] = w[11].clamp(0.01, 0.9) # forgetStabilityDecay\n", + " w[12] = w[12].clamp(0.01, 2) # forgetRetrievabilityFactor\n", + " module.w.data = w\n", + "\n", + "def lineToTensor(line):\n", + " ivl = line[0].split(',')\n", + " response = line[1].split(',')\n", + " tensor = torch.zeros(len(response), 2)\n", + " for li, response in enumerate(response):\n", + " tensor[li][0] = int(ivl[li])\n", + " tensor[li][1] = int(response)\n", + " return tensor\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8E1dYfgQLZAC" + }, + "source": [ + "### 2.2 Train the model\n", + "\n", + "The `revlog_history.tsv` generated before will be used for training the FSRS model." 
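As a quick sanity check of the input encoding, here is `lineToTensor` (defined in 2.1) applied to one toy history pair; the call pattern matches how the training cell builds `dataset['tensor']`:

```python
# Assumes lineToTensor from section 2.1 is in scope. Each review becomes
# one [delta_t, rating] row; the model is then unrolled over these rows
# to produce the final (stability, difficulty) memory state.
t_history = "0,3,7"   # days since the previous review of this card
r_history = "3,3,4"   # 1: again, 2: hard, 3: good, 4: easy
tensor = lineToTensor(list(zip([t_history], [r_history]))[0])
print(tensor)
# tensor([[0., 3.],
#         [3., 3.],
#         [7., 4.]])
```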
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Jht0gneShowU", + "outputId": "aaa72b79-b454-483b-d746-df1a353b2c8f" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9fcd540afa254d6082b0f51400d633a4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/225934 [00:00 1) & (dataset['delta_t'] > 0) & (dataset['t_history'].str.count(',0') == 0)]\n", + "dataset['tensor'] = dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]), axis=1)\n", + "print(\"Tensorized!\")\n", + "\n", + "pre_train_set = dataset[dataset['i'] == 2]\n", + "# pretrain: fit the initial memory-state parameters on each card's first interval\n", + "epoch_len = len(pre_train_set)\n", + "n_epoch = 1\n", + "pbar = tqdm.notebook.tqdm(desc=\"pre-train\", colour=\"red\", total=epoch_len*n_epoch)\n", + "\n", + "for k in range(n_epoch):\n", + " for i, (_, row) in enumerate(shuffle(pre_train_set, random_state=2022 + k).iterrows()):\n", + " model.train()\n", + " optimizer.zero_grad()\n", + " output_t = [(model.zero, model.zero)]\n", + " for input_t in row['tensor']:\n", + " output_t.append(model(input_t, *output_t[-1]))\n", + " loss = model.loss(output_t[-1][0], row['delta_t'],\n", + " {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])\n", + " if np.isnan(loss.data.item()):\n", + " # Exception Case\n", + " print(row, output_t)\n", + " raise Exception('error case')\n", + " loss.backward()\n", + " optimizer.step()\n", + " model.apply(clipper)\n", + " pbar.update()\n", + "pbar.close()\n", + "for name, param in model.named_parameters():\n", + " print(f\"{name}: {list(map(lambda x: round(float(x), 4), param))}\")\n", + "\n", + "train_set = dataset[dataset['i'] > 2]\n", + "epoch_len = len(train_set)\n", + "n_epoch = 1\n", + "print_len = max(epoch_len*n_epoch // 10, 1)\n", + "pbar = tqdm.notebook.tqdm(desc=\"train\", colour=\"red\", total=epoch_len*n_epoch)\n", + "\n", + "for k in range(n_epoch):\n", + " for i, (_, row) in enumerate(shuffle(train_set, random_state=2022 + k).iterrows()):\n", + " model.train()\n", + " optimizer.zero_grad()\n", + " output_t = [(model.zero, model.zero)]\n", + " for input_t in row['tensor']:\n", + " output_t.append(model(input_t, *output_t[-1]))\n", + " loss = model.loss(output_t[-1][0], row['delta_t'],\n", + " {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])\n", + " if np.isnan(loss.data.item()):\n", + " # Exception Case\n", + " print(row, output_t)\n", + " raise Exception('error case')\n", + " loss.backward()\n", + " # keep the pre-trained initial stability parameters (w[0], w[1]) frozen\n", + " for param in model.parameters():\n", + " param.grad[:2] = torch.zeros(2)\n", + " optimizer.step()\n", + " model.apply(clipper)\n", + " pbar.update()\n", + "\n", + " if (k * epoch_len + i) % print_len == 0:\n", + " print(f\"iteration: {k * epoch_len + i + 1}\")\n", + " for name, param in model.named_parameters():\n", + " print(f\"{name}: {list(map(lambda x: round(float(x), 4), param))}\")\n", + "pbar.close()\n", + "\n", + "w = list(map(lambda x: round(float(x), 4), dict(model.named_parameters())['w'].data))\n", + "\n", + "print(\"\\nTraining finished!\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BZ4S2l7BWfzr" + }, + "source": [ + "## 3 Result\n", + "\n", + "Copy your optimal FSRS parameters from the output of the next code cell after it runs."
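For reference, the objective minimized above is the negative log-likelihood of each observed outcome under the forgetting curve: `model.loss(s, t, r)` computes -[r·ln p + (1 - r)·ln(1 - p)] with p = 0.9^(t/s). A scalar sanity check with illustrative numbers:

```python
import numpy as np

def nll(s: float, t: float, r: int) -> float:
    # p: predicted probability of recall after t days at stability s
    p = 0.9 ** (t / s)
    return -(r * np.log(p) + (1 - r) * np.log(1 - p))

# A recall observed right at the predicted 90% point is cheap...
print(round(nll(s=10.0, t=10.0, r=1), 3))  # 0.105 == -ln(0.9)
# ...while a lapse at the same point costs much more.
print(round(nll(s=10.0, t=10.0, r=0), 3))  # 2.303 == -ln(0.1)
```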
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NTnPSDA2QpUu", + "outputId": "49f487b9-69a7-4e96-b35a-7e027f478fbd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "var w = [1.014, 2.2933, 4.9588, -1.1608, -0.9954, 0.0234, 1.3923, -0.0484, 0.7363, 1.6937, -0.4708, 0.6032, 0.9762];\n" + ] + } + ], + "source": [ + "print(f\"var w = {w};\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I_zsoDyTaTrT" + }, + "source": [ + "You can see the memory states and intervals generated by FSRS as if you pressed `Good` at each review on the due date scheduled by FSRS." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "iws4rtP1WKBT", + "outputId": "890d0287-1a17-4c59-fbbf-ee54d79cd383" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1:again, 2:hard, 3:good, 4:easy\n", + "\n", + "first rating: 1\n", + "rating history: 1,3,3,3,3,3,3,3,3,3,3\n", + "interval history: 0,1,2,4,9,19,39,79,159,317,624\n", + "difficulty history: 0,7.3,7.2,7.2,7.1,7.1,7.0,7.0,6.9,6.9,6.8\n", + "\n", + "first rating: 2\n", + "rating history: 2,3,3,3,3,3,3,3,3,3,3\n", + "interval history: 0,3,8,19,44,100,223,489,1052,2226,4631\n", + "difficulty history: 0,6.1,6.1,6.1,6.0,6.0,6.0,6.0,5.9,5.9,5.9\n", + "\n", + "first rating: 3\n", + "rating history: 3,3,3,3,3,3,3,3,3,3,3\n", + "interval history: 0,6,16,42,107,265,641,1512,3483,7842,17280\n", + "difficulty history: 0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0\n", + "\n", + "first rating: 4\n", + "rating history: 4,3,3,3,3,3,3,3,3,3,3\n", + "interval history: 0,8,24,69,192,517,1348,3409,8376,20022,46625\n", + "difficulty history: 0,3.8,3.8,3.9,3.9,3.9,3.9,4.0,4.0,4.0,4.0\n", + "\n" + ] + } + ], + "source": [ + "requestRetention = 0.9 # recommended setting: 0.8 ~ 0.9\n", + "\n", + "\n", + "class Collection:\n", + " def __init__(self, w):\n", + " self.model = FSRS(w)\n", + "\n", + " def states(self, t_history, r_history):\n", + " with torch.no_grad():\n", + " line_tensor = lineToTensor(list(zip([t_history], [r_history]))[0])\n", + " output_t = [(self.model.zero, self.model.zero)]\n", + " for input_t in line_tensor:\n", + " output_t.append(self.model(input_t, *output_t[-1]))\n", + " return output_t[-1]\n", + "\n", + "\n", + "my_collection = Collection(w)\n", + "print(\"1:again, 2:hard, 3:good, 4:easy\\n\")\n", + "for first_rating in (1,2,3,4):\n", + " print(f'first rating: {first_rating}')\n", + " t_history = \"0\"\n", + " d_history = \"0\"\n", + " r_history = f\"{first_rating}\" # the first rating of the new card\n", + " # print(\"stability, difficulty, lapses\")\n", + " for i in range(10):\n", + " states = my_collection.states(t_history, r_history)\n", + " # print('{0:9.2f} {1:11.2f} {2:7.0f}'.format(\n", + " # *list(map(lambda x: round(float(x), 4), states))))\n", + " next_t = max(round(float(np.log(requestRetention)/np.log(0.9) * states[0])), 1)\n", + " difficulty = round(float(states[1]), 1)\n", + " t_history += f',{int(next_t)}'\n", + " d_history += f',{difficulty}'\n", + " r_history += \",3\"\n", + " print(f\"rating history: {r_history}\")\n", + " print(f\"interval history: {t_history}\")\n", + " print(f\"difficulty history: {d_history}\")\n", + " print('')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can change the `test_rating_sequence` to see the scheduling intervals under 
different ratings." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(tensor(5.6006), tensor(4.9588))\n", + "(tensor(15.8420), tensor(4.9588))\n", + "(tensor(41.8382), tensor(4.9588))\n", + "(tensor(106.9516), tensor(4.9588))\n", + "(tensor(265.4778), tensor(4.9588))\n", + "(tensor(21.7929), tensor(6.9030))\n", + "(tensor(4.3071), tensor(8.8017))\n", + "(tensor(6.9323), tensor(8.7118))\n", + "(tensor(11.5881), tensor(8.6240))\n", + "(tensor(19.6509), tensor(8.5382))\n", + "(tensor(33.1994), tensor(8.4545))\n", + "(tensor(55.7034), tensor(8.3727))\n", + "rating history: 3,3,3,3,3,1,1,3,3,3,3,3\n", + "interval history: 0,6,16,42,107,265,22,4,7,12,20,33,56\n", + "difficulty history: 0,5.0,5.0,5.0,5.0,5.0,6.9,8.8,8.7,8.6,8.5,8.5,8.4\n" + ] + } + ], + "source": [ + "test_rating_sequence = \"3,3,3,3,3,1,1,3,3,3,3,3\"\n", + "requestRetention = 0.9 # recommended setting: 0.8 ~ 0.9\n", + "easyBonus = 1.3\n", + "hardInterval = 1.2\n", + "\n", + "t_history = \"0\"\n", + "d_history = \"0\"\n", + "for i in range(len(test_rating_sequence.split(','))):\n", + " rating = test_rating_sequence[2*i]\n", + " last_t = int(t_history.split(',')[-1])\n", + " r_history = test_rating_sequence[:2*i+1]\n", + " states = my_collection.states(t_history, r_history)\n", + " print(states)\n", + " next_t = max(1,round(float(np.log(requestRetention)/np.log(0.9) * states[0])))\n", + " if rating == '4':\n", + " next_t = round(next_t * easyBonus)\n", + " elif rating == '2':\n", + " next_t = round(last_t * hardInterval)\n", + " t_history += f',{int(next_t)}'\n", + " difficulty = round(float(states[1]), 1)\n", + " d_history += f',{difficulty}'\n", + "print(f\"rating history: {test_rating_sequence}\")\n", + "print(f\"interval history: {t_history}\")\n", + "print(f\"difficulty history: {d_history}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Predict memory states for each review and save them in `prediction.tsv`.\n", + "\n", + "Meanwhile, it will count the distribution of difficulty." 
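The interval arithmetic shared by the two cells above can be isolated into one helper. This is a sketch rather than notebook code (`next_interval` is a hypothetical name), but the formula t = S · ln(R*) / ln(0.9) and the `easyBonus`/`hardInterval` adjustments mirror what the cells compute inline:

```python
import numpy as np

def next_interval(stability: float, request_retention: float,
                  rating: int, last_interval: int,
                  easy_bonus: float = 1.3, hard_interval: float = 1.2) -> int:
    # Solve 0.9 ** (t / S) == request_retention for t.
    t = max(1, round(float(np.log(request_retention) / np.log(0.9) * stability)))
    if rating == 4:            # Easy: stretch the interval
        t = round(t * easy_bonus)
    elif rating == 2:          # Hard: grow the previous interval instead
        t = round(last_interval * hard_interval)
    return t

print(next_interval(16.0, 0.9, rating=3, last_interval=6))  # 16: t == S at 90% retention
print(next_interval(16.0, 0.8, rating=3, last_interval=6))  # 34: lower target, longer gap
```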
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0bc168078d174ab88d48a1a24056bd7f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/119670 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "base = 1.01\n", + "index_len = 800\n", + "index_offset = 150\n", + "d_range = 10\n", + "d_offset = 1\n", + "r_repetitions = 1\n", + "f_repetitions = 2.3\n", + "max_repetitions = 200000\n", + "\n", + "type_block = dict()\n", + "type_count = dict()\n", + "last_t = type_sequence[0]\n", + "type_block[last_t] = 1\n", + "type_count[last_t] = 1\n", + "for t in type_sequence[1:]:\n", + " type_count[t] = type_count.setdefault(t, 0) + 1\n", + " if t != last_t:\n", + " type_block[t] = type_block.setdefault(t, 0) + 1\n", + " last_t = t\n", + "if 2 in type_count and 2 in type_block:\n", + " f_repetitions = round(type_count[2]/type_block[2] + 1, 1)\n", + "\n", + "def stability2index(stability):\n", + " return int(round(np.log(stability) / np.log(base)) + index_offset)\n", + "\n", + "def init_stability(d):\n", + " return max(((d - w[2]) / w[3] + 2) * w[1] + w[0], np.power(base, -index_offset))\n", + "\n", + "def cal_next_recall_stability(s, r, d, response):\n", + " if response == 1:\n", + " return s * (1 + np.exp(w[6]) * (11 - d) * np.power(s, w[7]) * (np.exp((1 - r) * w[8]) - 1))\n", + " else:\n", + " return w[9] * np.power(d, w[10]) * np.power(s, w[11]) * np.exp((1 - r) * w[12])\n", + "\n", + "\n", + "stability_list = np.array([np.power(base, i - index_offset) for i in range(index_len)])\n", + "print(f\"terminal stability: {stability_list.max(): .2f}\")\n", + "df = pd.DataFrame(columns=[\"retention\", \"difficulty\", \"repetitions\"])\n", + "\n", + "for percentage in tqdm.notebook.tqdm(range(96, 70, -2)):\n", + " recall = percentage / 100\n", + " repetitions_list = np.zeros((d_range, index_len))\n", + " repetitions_list[:,:-1] = max_repetitions\n", + " for d in range(d_range, 0, -1):\n", + " s0 = init_stability(d)\n", + " s0_index = stability2index(s0)\n", + " diff = max_repetitions\n", + " while diff > 0.1:\n", + " s0_repetitions = repetitions_list[d - 1][s0_index]\n", + " for s_index in range(index_len - 2, -1, -1):\n", + " stability = stability_list[s_index];\n", + " interval = max(1, round(stability * np.log(recall) / np.log(0.9)))\n", + " p_recall = np.power(0.9, interval / stability)\n", + " recall_s = cal_next_recall_stability(stability, p_recall, d, 1)\n", + " forget_d = min(d + d_offset, 10)\n", + " forget_s = cal_next_recall_stability(stability, p_recall, forget_d, 0)\n", + " recall_s_index = min(stability2index(recall_s), index_len - 1)\n", + " forget_s_index = min(max(stability2index(forget_s), 0), index_len - 1)\n", + " recall_repetitions = repetitions_list[d - 1][recall_s_index] + r_repetitions\n", + " forget_repetitions = repetitions_list[forget_d - 1][forget_s_index] + f_repetitions\n", + " exp_repetitions = p_recall * recall_repetitions + (1.0 - p_recall) * forget_repetitions\n", + " if exp_repetitions < repetitions_list[d - 1][s_index]:\n", + " repetitions_list[d - 1][s_index] = exp_repetitions\n", + " diff = s0_repetitions - repetitions_list[d - 1][s0_index]\n", + " df.loc[0 if pd.isnull(df.index.max()) else df.index.max() + 1] = [recall, d, s0_repetitions]\n", + "\n", + "df.sort_values(by=[\"difficulty\", \"retention\"], inplace=True)\n", + "df.to_csv(\"./expected_repetitions.csv\", index=False)\n", + 
"print(\"expected_repetitions.csv saved.\")\n", + "\n", + "optimal_retention_list = np.zeros(10)\n", + "for d in range(1, d_range+1):\n", + " retention = df[df[\"difficulty\"] == d][\"retention\"]\n", + " repetitions = df[df[\"difficulty\"] == d][\"repetitions\"]\n", + " optimal_retention = retention.iat[repetitions.argmin()]\n", + " optimal_retention_list[d-1] = optimal_retention\n", + " plt.plot(retention, repetitions, label=f\"d={d}, r={optimal_retention}\")\n", + "print(f\"\\n-----suggested retention: {np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}-----\")\n", + "plt.ylabel(\"expected repetitions\")\n", + "plt.xlabel(\"retention\")\n", + "plt.legend()\n", + "plt.grid()\n", + "plt.semilogy()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Evaluate the model with the log loss. It will compare the log loss between initial model and trained model." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6dd615097589494b9829477848785fcb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/225934 [00:00