{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# FSRS4Anki v3.13.0 Optimizer" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "lurCmW0Jqz3s" }, "source": [ "[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v3.13.0/fsrs4anki_optimizer.ipynb)\n", "\n", "↑ Click the above button to open the optimizer on Google Colab.\n", "\n", "> If you can't see the button and are located in the Chinese Mainland, please use a proxy or VPN." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "id": "wG7bBfGJFbMr" }, "source": [ "Upload your **Anki Deck Package (.apkg)** file or **Anki Collection Package (.colpkg)** file on the `Left sidebar -> Files`, drag and drop your file in the current directory (not the `sample_data` directory). \n", "\n", "No need to include media. Need to include scheduling information. \n", "\n", "> If you use the latest version of Anki, please check the box `Support older Anki versions (slower/larger files)` when you export.\n", "\n", "You can export it via `File -> Export...` or `Ctrl + E` in the main window of Anki.\n", "\n", "Then replace the `filename` with yours in the next code cell. And set the `timezone` and `next_day_starts_at` which can be found in your preferences of Anki.\n", "\n", "After that, just run all (`Runtime -> Run all` or `Ctrl + F9`) and wait for minutes. You can see the optimal parameters in section **2.3 Result**. Copy them, replace the parameters in `fsrs4anki_scheduler.js`, and paste them into the custom scheduling of your deck options (require Anki version >= 2.1.55).\n", "\n", "**NOTE**: The default output is generated from my review logs. If you find the output is the same as mine, maybe your notebook hasn't run there.\n", "\n", "**Contribute to SRS Research**: If you want to share your data with me, please fill this form: https://forms.gle/KaojsBbhMCytaA7h8" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "iqP70_-3EUhi" }, "outputs": [], "source": [ "# Here are some settings that you need to replace before running this optimizer.\n", "\n", "filename = \"collection-2022-09-18@13-21-58.colpkg\"\n", "# If you upload deck file, replace it with your deck filename. E.g., ALL__Learning.apkg\n", "# If you upload collection file, replace it with your colpgk filename. E.g., collection-2022-09-18@13-21-58.colpkg\n", "\n", "# Replace it with your timezone. 
I'm in China, so I use Asia/Shanghai.\n", "# You can find your timezone here: https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568\n", "timezone = 'Asia/Shanghai'\n", "\n", "# Replace it with your Anki setting in Preferences -> Scheduling.\n", "next_day_starts_at = 4\n", "\n", "# Replace it if you don't want the optimizer to use review logs from before a specific date.\n", "revlog_start_date = \"2006-10-05\"\n" ] }, { "cell_type": "markdown", "metadata": { "id": "bLFVNmG2qd06" }, "source": [ "## 1 Build dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "EkzFeKawqgbs" }, "source": [ "### 1.1 Extract Anki collection & deck file" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KD2js_wEr_Bs", "outputId": "42653d9e-316e-40bc-bd1d-f3a0e2b246c7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Extracted successfully!\n" ] } ], "source": [ "import zipfile\n", "import sqlite3\n", "import time\n", "import tqdm\n", "import pandas as pd\n", "import numpy as np\n", "import os\n", "from datetime import timedelta, datetime\n", "import matplotlib.pyplot as plt\n", "import math\n", "import sys\n", "import torch\n", "from torch import nn\n", "from sklearn.utils import shuffle\n", "# Extract the collection file or deck file to get the .anki21 database.\n", "with zipfile.ZipFile(f'./{filename}', 'r') as zip_ref:\n", "    zip_ref.extractall('./')\n", "    print(\"Extracted successfully!\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "dKpy4VfqGmaL" }, "source": [ "### 1.2 Create time-series features & analysis\n", "\n", "The following code cell will extract the review logs from your Anki collection and preprocess them into a training set, which is saved in `revlog_history.tsv`.\n", "\n", "The time-series features are important for optimizing the model's parameters. For more details, please see my paper: https://www.maimemo.com/paper/\n", "\n", "Then it will generate a concise analysis of your review logs.\n", "\n", "- The `r_history` is the history of ratings on each review. `3,3,3,1` means that you pressed `Good, Good, Good, Again`. It only contains the first rating for each card on a given review date, i.e., if you press `Again` during review and `Good` in the relearning step 10 minutes later, only `Again` is recorded.\n", "- The `avg_interval` is the actual average interval after you rate your cards as in the `r_history`. It could be longer than the interval given by Anki's built-in scheduler because you reviewed some overdue cards.\n", "- The `avg_retention` is the average retention after you rate your cards as in the `r_history`. `Again` counts as a failed recall, and `Hard`, `Good` and `Easy` count as successful recalls. Retention is the percentage of successful recalls.\n", "- The `stability` is the estimated memory state variable; it approximates the interval that leads to 90% retention.\n", "- The `factor` is `stability / previous stability`.\n", "- The `group_cnt` is the number of review logs that share the same `r_history`."
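, "\n", "\n", "To make the `stability` estimate concrete: the forgetting curve is assumed to be `R = 0.9 ** (t / S)`, so `ln(R) = ln(0.9) * t / S`, and the code below fits `S` for each `r_history` group by a count-weighted least-squares line through the origin on its `(delta_t, ln(retention))` pairs. A minimal unweighted sketch of that fit, with made-up numbers:\n", "\n", "```python\n", "import numpy as np\n", "\n", "# hypothetical (interval in days, measured retention) pairs for one r_history group\n", "delta_t = np.array([1, 3, 7])\n", "retention = np.array([0.96, 0.90, 0.78])\n", "\n", "# ln(R) = ln(0.9) * t / S  =>  slope = ln(0.9) / S  =>  S = ln(0.9) / slope\n", "slope = np.sum(delta_t * np.log(retention)) / np.sum(delta_t ** 2)\n", "stability = np.log(0.9) / slope\n", "print(round(stability, 1))  # ~3.0 days: the interval that would yield 90% retention\n", "```"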
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "J2IIaY3PDaaG", "outputId": "607916c9-da95-48dd-fdab-6bd83fbbbb40" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "revlog.csv saved.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "191e41fc14f34c1789b34cf09aaf92cb", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30711 [00:00 0) &\n", " (df['id'] >= time.mktime(datetime.strptime(revlog_start_date, \"%Y-%m-%d\").timetuple()) * 1000)].copy()\n", "df['create_date'] = pd.to_datetime(df['cid'] // 1000, unit='s')\n", "df['create_date'] = df['create_date'].dt.tz_localize(\n", " 'UTC').dt.tz_convert(timezone)\n", "df['review_date'] = pd.to_datetime(df['id'] // 1000, unit='s')\n", "df['review_date'] = df['review_date'].dt.tz_localize(\n", " 'UTC').dt.tz_convert(timezone)\n", "df.drop(df[df['review_date'].dt.year < 2006].index, inplace=True)\n", "df.sort_values(by=['cid', 'id'], inplace=True, ignore_index=True)\n", "type_sequence = np.array(df['type'])\n", "time_sequence = np.array(df['time'])\n", "df.to_csv(\"revlog.csv\", index=False)\n", "print(\"revlog.csv saved.\")\n", "df = df[(df['type'] == 0) | (df['type'] == 1)].copy()\n", "df['real_days'] = df['review_date'] - timedelta(hours=next_day_starts_at)\n", "df['real_days'] = pd.DatetimeIndex(df['real_days'].dt.floor('D')).to_julian_date()\n", "df.drop_duplicates(['cid', 'real_days'], keep='first', inplace=True)\n", "df['delta_t'] = df.real_days.diff()\n", "df.dropna(inplace=True)\n", "df['delta_t'] = df['delta_t'].astype(dtype=int)\n", "df['i'] = 1\n", "df['r_history'] = \"\"\n", "df['t_history'] = \"\"\n", "col_idx = {key: i for i, key in enumerate(df.columns)}\n", "\n", "\n", "# code from https://github.com/L-M-Sherlock/anki_revlog_analysis/blob/main/revlog_analysis.py\n", "def get_feature(x):\n", " for idx, log in enumerate(x.itertuples()):\n", " if idx == 0:\n", " x.iloc[idx, col_idx['delta_t']] = 0\n", " if idx == x.shape[0] - 1:\n", " break\n", " x.iloc[idx + 1, col_idx['i']] = x.iloc[idx, col_idx['i']] + 1\n", " x.iloc[idx + 1, col_idx['t_history']] = f\"{x.iloc[idx, col_idx['t_history']]},{x.iloc[idx, col_idx['delta_t']]}\"\n", " x.iloc[idx + 1, col_idx['r_history']] = f\"{x.iloc[idx, col_idx['r_history']]},{x.iloc[idx, col_idx['r']]}\"\n", " return x\n", "\n", "tqdm.notebook.tqdm.pandas()\n", "df = df.groupby('cid', as_index=False).progress_apply(get_feature)\n", "df[\"t_history\"] = df[\"t_history\"].map(lambda x: x[1:] if len(x) > 1 else x)\n", "df[\"r_history\"] = df[\"r_history\"].map(lambda x: x[1:] if len(x) > 1 else x)\n", "df.to_csv('revlog_history.tsv', sep=\"\\t\", index=False)\n", "print(\"Trainset saved.\")\n", "\n", "def cal_retention(group: pd.DataFrame) -> pd.DataFrame:\n", " group['retention'] = round(group['r'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x]).mean(), 4)\n", " group['total_cnt'] = group.shape[0]\n", " return group\n", "\n", "df = df.groupby(by=['r_history', 'delta_t']).progress_apply(cal_retention)\n", "print(\"Retention calculated.\")\n", "df = df.drop(columns=['id', 'cid', 'usn', 'ivl', 'last_lvl', 'factor', 'time', 'type', 'create_date', 'review_date', 'real_days', 'r', 't_history'])\n", "df.drop_duplicates(inplace=True)\n", "df = df[(df['retention'] < 1) & (df['retention'] > 0)]\n", "\n", "def cal_stability(group: pd.DataFrame) -> pd.DataFrame:\n", " if group['i'].values[0] > 1:\n", " r_ivl_cnt = sum(group['delta_t'] * group['retention'].map(np.log) * 
pow(group['total_cnt'], 2))\n", "        ivl_ivl_cnt = sum(group['delta_t'].map(lambda x: x ** 2) * pow(group['total_cnt'], 2))\n", "        group['stability'] = round(np.log(0.9) / (r_ivl_cnt / ivl_ivl_cnt), 1)\n", "    else:\n", "        group['stability'] = 0.0\n", "    group['group_cnt'] = sum(group['total_cnt'])\n", "    group['avg_retention'] = round(sum(group['retention'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 3)\n", "    group['avg_interval'] = round(sum(group['delta_t'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 1)\n", "    del group['total_cnt']\n", "    del group['retention']\n", "    del group['delta_t']\n", "    return group\n", "\n", "df = df.groupby(by=['r_history']).progress_apply(cal_stability)\n", "print(\"Stability calculated.\")\n", "df.reset_index(drop=True, inplace=True)\n", "df.drop_duplicates(inplace=True)\n", "df.sort_values(by=['r_history'], inplace=True, ignore_index=True)\n", "\n", "if df.shape[0] > 0:\n", "    for idx in tqdm.notebook.tqdm(df.index):\n", "        item = df.loc[idx]\n", "        index = df[(df['i'] == item['i'] + 1) & (df['r_history'].str.startswith(item['r_history']))].index\n", "        df.loc[index, 'last_stability'] = item['stability']\n", "    df['factor'] = round(df['stability'] / df['last_stability'], 2)\n", "    df = df[(df['i'] >= 2) & (df['group_cnt'] >= 100)]\n", "    df['last_recall'] = df['r_history'].map(lambda x: x[-1])\n", "    df = df[df.groupby(['i', 'r_history'])['group_cnt'].transform('max') == df['group_cnt']]\n", "    df.to_csv('./stability_for_analysis.tsv', sep='\\t', index=False)\n", "    print(\"1:again, 2:hard, 3:good, 4:easy\\n\")\n", "    print(df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(index=False))\n", "    print(\"Analysis saved!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "k_SgzC-auWmu" }, "source": [ "## 2 Optimize parameters" ] }, { "cell_type": "markdown", "metadata": { "id": "WrfBJjqCHEwJ" }, "source": [ "### 2.1 Define the model\n", "\n", "FSRS is a time-series model for predicting memory states."
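, "\n", "\n", "Concretely, FSRS tracks two state variables per card, stability `S` and difficulty `D`, and models retrievability after `t` days as `R = 0.9 ** (t / S)`; solving that curve for `t` gives the interval that hits a requested retention. A minimal sketch of this relation (it mirrors the `next_t` formula used in the preview cells below; `request_retention` is an assumed name):\n", "\n", "```python\n", "import numpy as np\n", "\n", "def retrievability(t: float, stability: float) -> float:\n", "    # probability of recall t days after the last review\n", "    return 0.9 ** (t / stability)\n", "\n", "def next_interval(stability: float, request_retention: float = 0.9) -> int:\n", "    # solve 0.9 ** (t / S) = request_retention for t\n", "    return max(1, round(stability * np.log(request_retention) / np.log(0.9)))\n", "\n", "print(next_interval(5.0, 0.9))  # 5: at the 90% target the interval equals S\n", "print(next_interval(5.0, 0.8))  # 11: a laxer target stretches the interval\n", "```"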
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "tdYp3GMLhTYm" }, "outputs": [], "source": [ "init_w = [1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.12, 0.8, 2, -0.2, 0.2, 1]\n", "'''\n", "w[0]: initial_stability_for_again_answer\n", "w[1]: initial_stability_step_per_rating\n", "w[2]: initial_difficulty_for_good_answer\n", "w[3]: initial_difficulty_step_per_rating\n", "w[4]: next_difficulty_step_per_rating\n", "w[5]: next_difficulty_reversion_to_mean_speed (used to avoid ease hell)\n", "w[6]: next_stability_factor_after_success\n", "w[7]: next_stability_stabilization_decay_after_success\n", "w[8]: next_stability_retrievability_gain_after_success\n", "w[9]: next_stability_factor_after_failure\n", "w[10]: next_stability_difficulty_decay_after_success\n", "w[11]: next_stability_stability_gain_after_failure\n", "w[12]: next_stability_retrievability_gain_after_failure\n", "For more details about the parameters, please see: \n", "https://github.com/open-spaced-repetition/fsrs4anki/wiki/Free-Spaced-Repetition-Scheduler\n", "'''\n", "\n", "\n", "class FSRS(nn.Module):\n", " def __init__(self, w):\n", " super(FSRS, self).__init__()\n", " self.w = nn.Parameter(torch.FloatTensor(w))\n", " self.zero = torch.FloatTensor([0.0])\n", "\n", " def forward(self, x, s, d):\n", " '''\n", " :param x: [review interval, review response]\n", " :param s: stability\n", " :param d: difficulty\n", " :return:\n", " '''\n", " if torch.equal(s, self.zero):\n", " # first learn, init memory states\n", " new_s = self.w[0] + self.w[1] * (x[1] - 1)\n", " new_d = self.w[2] + self.w[3] * (x[1] - 3)\n", " new_d = new_d.clamp(1, 10)\n", " else:\n", " r = torch.exp(np.log(0.9) * x[0] / s)\n", " new_d = d + self.w[4] * (x[1] - 3)\n", " new_d = self.mean_reversion(self.w[2], new_d)\n", " new_d = new_d.clamp(1, 10)\n", " # recall\n", " if x[1] > 1:\n", " new_s = s * (1 + torch.exp(self.w[6]) *\n", " (11 - new_d) *\n", " torch.pow(s, self.w[7]) *\n", " (torch.exp((1 - r) * self.w[8]) - 1))\n", " # forget\n", " else:\n", " new_s = self.w[9] * torch.pow(new_d, self.w[10]) * torch.pow(\n", " s, self.w[11]) * torch.exp((1 - r) * self.w[12])\n", " return new_s, new_d\n", "\n", " def loss(self, s, t, r):\n", " return - (r * np.log(0.9) * t / s + (1 - r) * torch.log(1 - torch.exp(np.log(0.9) * t / s)))\n", "\n", " def mean_reversion(self, init, current):\n", " return self.w[5] * init + (1-self.w[5]) * current\n", "\n", "\n", "class WeightClipper(object):\n", " def __init__(self, frequency=1):\n", " self.frequency = frequency\n", "\n", " def __call__(self, module):\n", " if hasattr(module, 'w'):\n", " w = module.w.data\n", " w[0] = w[0].clamp(0.1, 10)\n", " w[1] = w[1].clamp(0.1, 5)\n", " w[2] = w[2].clamp(1, 10)\n", " w[3] = w[3].clamp(-5, -0.1)\n", " w[4] = w[4].clamp(-5, -0.1)\n", " w[5] = w[5].clamp(0, 0.5)\n", " w[6] = w[6].clamp(0, 2)\n", " w[7] = w[7].clamp(-0.2, -0.01)\n", " w[8] = w[8].clamp(0.01, 1.5)\n", " w[9] = w[9].clamp(0.5, 5)\n", " w[10] = w[10].clamp(-2, -0.01)\n", " w[11] = w[11].clamp(0.01, 0.9)\n", " w[12] = w[12].clamp(0.01, 2)\n", " module.w.data = w\n", "\n", "def lineToTensor(line):\n", " ivl = line[0].split(',')\n", " response = line[1].split(',')\n", " tensor = torch.zeros(len(response), 2)\n", " for li, response in enumerate(response):\n", " tensor[li][0] = int(ivl[li])\n", " tensor[li][1] = int(response)\n", " return tensor\n" ] }, { "cell_type": "markdown", "metadata": { "id": "8E1dYfgQLZAC" }, "source": [ "### 2.2 Train the model\n", "\n", "The `revlog_history.tsv` generated before will be used for 
training the FSRS model." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Jht0gneShowU", "outputId": "aaa72b79-b454-483b-d746-df1a353b2c8f" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f875359e6b654aeb81d71ba1d3aa10f7", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/225934 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Initialize the model, the weight clipper and the optimizer, then load the training set.\n", "model = FSRS(init_w)\n", "clipper = WeightClipper()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)\n", "\n", "dataset = pd.read_csv('./revlog_history.tsv', sep='\\t', index_col=None, dtype={'r_history': str, 't_history': str})\n", "dataset = dataset[(dataset['i'] > 1) & (dataset['delta_t'] > 0) & (dataset['t_history'].str.count(',0') == 0)]\n", "dataset['tensor'] = dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]), axis=1)\n", "print(\"Tensorized!\")\n", "\n", "pre_train_set = dataset[dataset['i'] == 2]\n", "# pretrain\n", "epoch_len = len(pre_train_set)\n", "n_epoch = 1\n", "pbar = tqdm.notebook.tqdm(desc=\"pre-train\", colour=\"red\", total=epoch_len*n_epoch)\n", "\n", "for k in range(n_epoch):\n", "    for i, (_, row) in enumerate(shuffle(pre_train_set, random_state=2022 + k).iterrows()):\n", "        model.train()\n", "        optimizer.zero_grad()\n", "        output_t = [(model.zero, model.zero)]\n", "        for input_t in row['tensor']:\n", "            output_t.append(model(input_t, *output_t[-1]))\n", "        loss = model.loss(output_t[-1][0], row['delta_t'],\n", "                          {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])\n", "        if np.isnan(loss.data.item()):\n", "            # Exception Case\n", "            print(row, output_t)\n", "            raise Exception('error case')\n", "        loss.backward()\n", "        optimizer.step()\n", "        model.apply(clipper)\n", "        pbar.update()\n", "pbar.close()\n", "for name, param in model.named_parameters():\n", "    print(f\"{name}: {list(map(lambda x: round(float(x), 4), param))}\")\n", "\n", "train_set = dataset[dataset['i'] > 2]\n", "epoch_len = len(train_set)\n", "n_epoch = 1\n", "print_len = max(epoch_len*n_epoch // 10, 1)\n", "pbar = tqdm.notebook.tqdm(desc=\"train\", colour=\"red\", total=epoch_len*n_epoch)\n", "\n", "for k in range(n_epoch):\n", "    for i, (_, row) in enumerate(shuffle(train_set, random_state=2022 + k).iterrows()):\n", "        model.train()\n", "        optimizer.zero_grad()\n", "        output_t = [(model.zero, model.zero)]\n", "        for input_t in row['tensor']:\n", "            output_t.append(model(input_t, *output_t[-1]))\n", "        loss = model.loss(output_t[-1][0], row['delta_t'],\n", "                          {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])\n", "        if np.isnan(loss.data.item()):\n", "            # Exception Case\n", "            print(row, output_t)\n", "            raise Exception('error case')\n", "        loss.backward()\n", "        for param in model.parameters():\n", "            param.grad[:2] = torch.zeros(2)\n", "        optimizer.step()\n", "        model.apply(clipper)\n", "        pbar.update()\n", "\n", "        if (k * epoch_len + i) % print_len == 0:\n", "            print(f\"iteration: {k * epoch_len + i + 1}\")\n", "            for name, param in model.named_parameters():\n", "                print(f\"{name}: {list(map(lambda x: round(float(x), 4), param))}\")\n", "pbar.close()\n", "\n", "w = list(map(lambda x: round(float(x), 4), dict(model.named_parameters())['w'].data))\n", "\n", "print(\"\\nTraining finished!\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "BZ4S2l7BWfzr" }, "source": [ "### 2.3 Result\n", "\n", "After running the next code cell, copy your optimal FSRS parameters from its output."
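, "\n", "\n", "The printed line is meant to replace the corresponding `var w = [...]` line of `fsrs4anki_scheduler.js` before you paste the script into your deck options' custom scheduling. For example, with this notebook's default output (only the `var w` line comes from the optimizer; the comment is illustrative):\n", "\n", "```javascript\n", "// fsrs4anki_scheduler.js (excerpt): replace this line with your own output\n", "var w = [1.014, 2.2933, 4.9444, -1.1646, -0.9942, 0.0227, 1.3911, -0.0498, 0.7376, 1.7016, -0.4742, 0.602, 0.9946];\n", "```"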
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NTnPSDA2QpUu", "outputId": "49f487b9-69a7-4e96-b35a-7e027f478fbd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "var w = [1.014, 2.2933, 4.9444, -1.1646, -0.9942, 0.0227, 1.3911, -0.0498, 0.7376, 1.7016, -0.4742, 0.602, 0.9946];\n" ] } ], "source": [ "print(f\"var w = {w};\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### 2.4 Preview" ] }, { "cell_type": "markdown", "metadata": { "id": "I_zsoDyTaTrT" }, "source": [ "You can see the memory states and intervals generated by FSRS as if you press the good in each review at the due date scheduled by FSRS." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iws4rtP1WKBT", "outputId": "890d0287-1a17-4c59-fbbf-ee54d79cd383" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1:again, 2:hard, 3:good, 4:easy\n", "\n", "first rating: 1\n", "rating history: 1,3,3,3,3,3,3,3,3,3,3\n", "interval history: 0,1,2,4,9,19,39,79,159,315,618\n", "difficulty history: 0,7.3,7.2,7.2,7.1,7.1,7.0,7.0,6.9,6.9,6.8\n", "\n", "first rating: 2\n", "rating history: 2,3,3,3,3,3,3,3,3,3,3\n", "interval history: 0,3,8,19,44,100,222,485,1041,2193,4543\n", "difficulty history: 0,6.1,6.1,6.1,6.0,6.0,6.0,6.0,5.9,5.9,5.9\n", "\n", "first rating: 3\n", "rating history: 3,3,3,3,3,3,3,3,3,3,3\n", "interval history: 0,6,16,42,107,265,639,1502,3446,7725,16941\n", "difficulty history: 0,4.9,4.9,4.9,4.9,4.9,4.9,4.9,4.9,4.9,4.9\n", "\n", "first rating: 4\n", "rating history: 4,3,3,3,3,3,3,3,3,3,3\n", "interval history: 0,8,24,69,192,515,1339,3374,8256,19643,45511\n", "difficulty history: 0,3.8,3.8,3.8,3.9,3.9,3.9,3.9,4.0,4.0,4.0\n", "\n" ] } ], "source": [ "requestRetention = 0.9 # recommended setting: 0.8 ~ 0.9\n", "\n", "\n", "class Collection:\n", " def __init__(self, w):\n", " self.model = FSRS(w)\n", "\n", " def states(self, t_history, r_history):\n", " with torch.no_grad():\n", " line_tensor = lineToTensor(list(zip([t_history], [r_history]))[0])\n", " output_t = [(self.model.zero, self.model.zero)]\n", " for input_t in line_tensor:\n", " output_t.append(self.model(input_t, *output_t[-1]))\n", " return output_t[-1]\n", "\n", "\n", "my_collection = Collection(w)\n", "print(\"1:again, 2:hard, 3:good, 4:easy\\n\")\n", "for first_rating in (1,2,3,4):\n", " print(f'first rating: {first_rating}')\n", " t_history = \"0\"\n", " d_history = \"0\"\n", " r_history = f\"{first_rating}\" # the first rating of the new card\n", " # print(\"stability, difficulty, lapses\")\n", " for i in range(10):\n", " states = my_collection.states(t_history, r_history)\n", " # print('{0:9.2f} {1:11.2f} {2:7.0f}'.format(\n", " # *list(map(lambda x: round(float(x), 4), states))))\n", " next_t = max(round(float(np.log(requestRetention)/np.log(0.9) * states[0])), 1)\n", " difficulty = round(float(states[1]), 1)\n", " t_history += f',{int(next_t)}'\n", " d_history += f',{difficulty}'\n", " r_history += f\",3\"\n", " print(f\"rating history: {r_history}\")\n", " print(f\"interval history: {t_history}\")\n", " print(f\"difficulty history: {d_history}\")\n", " print('')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can change the `test_rating_sequence` to see the scheduling intervals in different ratings." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(tensor(5.6006), tensor(4.9444))\n", "(tensor(15.8482), tensor(4.9444))\n", "(tensor(41.8218), tensor(4.9444))\n", "(tensor(106.7919), tensor(4.9444))\n", "(tensor(264.7672), tensor(4.9444))\n", "(tensor(21.6391), tensor(6.8877))\n", "(tensor(4.2755), tensor(8.7868))\n", "(tensor(6.9117), tensor(8.6996))\n", "(tensor(11.5771), tensor(8.6143))\n", "(tensor(19.6411), tensor(8.5310))\n", "(tensor(33.1676), tensor(8.4496))\n", "(tensor(55.5986), tensor(8.3701))\n", "rating history: 3,3,3,3,3,1,1,3,3,3,3,3\n", "interval history: 0,6,16,42,107,265,22,4,7,12,20,33,56\n", "difficulty history: 0,4.9,4.9,4.9,4.9,4.9,6.9,8.8,8.7,8.6,8.5,8.4,8.4\n" ] } ], "source": [ "test_rating_sequence = \"3,3,3,3,3,1,1,3,3,3,3,3\"\n", "requestRetention = 0.9 # recommended setting: 0.8 ~ 0.9\n", "easyBonus = 1.3\n", "hardInterval = 1.2\n", "\n", "t_history = \"0\"\n", "d_history = \"0\"\n", "for i in range(len(test_rating_sequence.split(','))):\n", " rating = test_rating_sequence[2*i]\n", " last_t = int(t_history.split(',')[-1])\n", " r_history = test_rating_sequence[:2*i+1]\n", " states = my_collection.states(t_history, r_history)\n", " print(states)\n", " next_t = max(1,round(float(np.log(requestRetention)/np.log(0.9) * states[0])))\n", " if rating == '4':\n", " next_t = round(next_t * easyBonus)\n", " elif rating == '2':\n", " next_t = round(last_t * hardInterval)\n", " t_history += f',{int(next_t)}'\n", " difficulty = round(float(states[1]), 1)\n", " d_history += f',{difficulty}'\n", "print(f\"rating history: {test_rating_sequence}\")\n", "print(f\"interval history: {t_history}\")\n", "print(f\"difficulty history: {d_history}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.5 Predict memory states and distribution of difficulty\n", "\n", "Predict memory states for each review and save them in `prediction.tsv`.\n", "\n", "Meanwhile, it will count the distribution of difficulty." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f99be96c169b45f8a03d1355e71b679a", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/119670 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "base = 1.01\n", "index_len = 793\n", "index_offset = 200\n", "d_range = 10\n", "d_offset = 1\n", "r_time = 8\n", "f_time = 25\n", "max_time = 200000\n", "\n", "type_block = dict()\n", "type_count = dict()\n", "type_time = dict()\n", "last_t = type_sequence[0]\n", "type_block[last_t] = 1\n", "type_count[last_t] = 1\n", "type_time[last_t] = time_sequence[0]\n", "for i,t in enumerate(type_sequence[1:]):\n", " type_count[t] = type_count.setdefault(t, 0) + 1\n", " type_time[t] = type_time.setdefault(t, 0) + time_sequence[i]\n", " if t != last_t:\n", " type_block[t] = type_block.setdefault(t, 0) + 1\n", " last_t = t\n", "\n", "r_time = round(type_time[1]/type_count[1]/1000, 1)\n", "\n", "if 2 in type_count and 2 in type_block:\n", " f_time = round(type_time[2]/type_block[2]/1000 + r_time, 1)\n", "\n", "print(f\"average time for failed cards: {f_time}s\")\n", "print(f\"average time for recalled cards: {r_time}s\")\n", "\n", "def stability2index(stability):\n", " return int(round(np.log(stability) / np.log(base)) + index_offset)\n", "\n", "def init_stability(d):\n", " return max(((d - w[2]) / w[3] + 2) * w[1] + w[0], np.power(base, -index_offset))\n", "\n", "def cal_next_recall_stability(s, r, d, response):\n", " if response == 1:\n", " return s * (1 + np.exp(w[6]) * (11 - d) * np.power(s, w[7]) * (np.exp((1 - r) * w[8]) - 1))\n", " else:\n", " return w[9] * np.power(d, w[10]) * np.power(s, w[11]) * np.exp((1 - r) * w[12])\n", "\n", "\n", "stability_list = np.array([np.power(base, i - index_offset) for i in range(index_len)])\n", "print(f\"terminal stability: {stability_list.max(): .2f}\")\n", "df = pd.DataFrame(columns=[\"retention\", \"difficulty\", \"time\"])\n", "\n", "for percentage in tqdm.notebook.tqdm(range(96, 66, -2)):\n", " recall = percentage / 100\n", " time_list = np.zeros((d_range, index_len))\n", " time_list[:,:-1] = max_time\n", " for d in range(d_range, 0, -1):\n", " s0 = init_stability(d)\n", " s0_index = stability2index(s0)\n", " diff = max_time\n", " while diff > 0.1:\n", " s0_time = time_list[d - 1][s0_index]\n", " for s_index in range(index_len - 2, -1, -1):\n", " stability = stability_list[s_index];\n", " interval = max(1, round(stability * np.log(recall) / np.log(0.9)))\n", " p_recall = np.power(0.9, interval / stability)\n", " recall_s = cal_next_recall_stability(stability, p_recall, d, 1)\n", " forget_d = min(d + d_offset, 10)\n", " forget_s = cal_next_recall_stability(stability, p_recall, forget_d, 0)\n", " recall_s_index = min(stability2index(recall_s), index_len - 1)\n", " forget_s_index = min(max(stability2index(forget_s), 0), index_len - 1)\n", " recall_time = time_list[d - 1][recall_s_index] + r_time\n", " forget_time = time_list[forget_d - 1][forget_s_index] + f_time\n", " exp_time = p_recall * recall_time + (1.0 - p_recall) * forget_time\n", " if exp_time < time_list[d - 1][s_index]:\n", " time_list[d - 1][s_index] = exp_time\n", " diff = s0_time - time_list[d - 1][s0_index]\n", " df.loc[0 if pd.isnull(df.index.max()) else df.index.max() + 1] = [recall, d, s0_time]\n", "\n", "df.sort_values(by=[\"difficulty\", \"retention\"], inplace=True)\n", "df.to_csv(\"./expected_time.csv\", index=False)\n", "print(\"expected_time.csv 
saved.\")\n", "\n", "optimal_retention_list = np.zeros(10)\n", "for d in range(1, d_range+1):\n", " retention = df[df[\"difficulty\"] == d][\"retention\"]\n", " time = df[df[\"difficulty\"] == d][\"time\"]\n", " optimal_retention = retention.iat[time.argmin()]\n", " optimal_retention_list[d-1] = optimal_retention\n", " plt.plot(retention, time, label=f\"d={d}, r={optimal_retention}\")\n", "print(f\"\\n-----suggested retention (experimental): {np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}-----\")\n", "plt.ylabel(\"expected time (second)\")\n", "plt.xlabel(\"retention\")\n", "plt.legend()\n", "plt.grid()\n", "plt.semilogy()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4 Evaluate the model\n", "\n", "Evaluate the model with the log loss. It will compare the log loss between initial model and trained model." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ca34a71cc6314383b98b249b4754f2a7", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/225934 [00:00