diff --git "a/fsrs4anki_optimizer.ipynb" "b/fsrs4anki_optimizer.ipynb" deleted file mode 100644 --- "a/fsrs4anki_optimizer.ipynb" +++ /dev/null @@ -1,1227 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# FSRS4Anki v3.13.3 Optimizer" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "id": "lurCmW0Jqz3s" - }, - "source": [ - "[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v3.13.3/fsrs4anki_optimizer.ipynb)\n", - "\n", - "↑ Click the button above to open the optimizer on Google Colab.\n", - "\n", - "> If you can't see the button and are located in the Chinese Mainland, please use a proxy or VPN." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "id": "wG7bBfGJFbMr" - }, - "source": [ - "Upload your **Anki Deck Package (.apkg)** file or **Anki Collection Package (.colpkg)** file via the `Left sidebar -> Files` panel: drag and drop your file into the current directory (not the `sample_data` directory). \n", - "\n", - "You don't need to include media, but you do need to include scheduling information. \n", - "\n", - "> If you use the latest version of Anki, please check the box `Support older Anki versions (slower/larger files)` when you export.\n", - "\n", - "You can export it via `File -> Export...` or `Ctrl + E` in the main window of Anki.\n", - "\n", - "Then replace the `filename` in the next code cell with yours, and set `timezone` and `next_day_starts_at`, both of which can be found in Anki's preferences.\n", - "\n", - "After that, just run all (`Runtime -> Run all` or `Ctrl + F9`) and wait a few minutes. You can see the optimal parameters in section **2.3 Result**. Copy them, use them to replace the parameters in `fsrs4anki_scheduler.js`, and paste the result into the custom scheduling section of your deck options (requires Anki version >= 2.1.55).\n", - "\n", - "**NOTE**: The default output is generated from my review logs. If your output is identical to mine, that cell probably hasn't been run yet.\n", - "\n", - "**Contribute to SRS Research**: If you want to share your data with me, please fill out this form: https://forms.gle/KaojsBbhMCytaA7h8" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "iqP70_-3EUhi" - }, - "outputs": [], - "source": [ - "# Here are some settings that you need to replace before running this optimizer.\n", - "\n", - "filename = \"collection-2022-09-18@13-21-58.colpkg\"\n", - "# If you upload a deck file, replace it with your deck filename. E.g., ALL__Learning.apkg\n", - "# If you upload a collection file, replace it with your colpkg filename. E.g., collection-2022-09-18@13-21-58.colpkg\n", - "\n", - "# Replace it with your timezone.
I'm in China, so I use Asia/Shanghai.\n", - "# You can find your timezone here: https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568\n", - "timezone = 'Asia/Shanghai'\n", - "\n", - "# Replace it with your Anki setting from Preferences -> Scheduling.\n", - "next_day_starts_at = 4\n", - "\n", - "# Replace it if you don't want the optimizer to use review logs from before a specific date.\n", - "revlog_start_date = \"2006-10-05\"\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bLFVNmG2qd06" - }, - "source": [ - "## 1 Build dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EkzFeKawqgbs" - }, - "source": [ - "### 1.1 Extract Anki collection & deck file" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "KD2js_wEr_Bs", - "outputId": "42653d9e-316e-40bc-bd1d-f3a0e2b246c7" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracted successfully!\n" - ] - } - ], - "source": [ - "import zipfile\n", - "import sqlite3\n", - "import time\n", - "import tqdm\n", - "import pandas as pd\n", - "import numpy as np\n", - "import os\n", - "from datetime import timedelta, datetime\n", - "import matplotlib.pyplot as plt\n", - "import math\n", - "import sys\n", - "import torch\n", - "from torch import nn\n", - "from sklearn.utils import shuffle\n", - "# Extract the collection file or deck file to get the .anki21 database.\n", - "with zipfile.ZipFile(f'./{filename}', 'r') as zip_ref:\n", - " zip_ref.extractall('./')\n", - " print(\"Extracted successfully!\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dKpy4VfqGmaL" - }, - "source": [ - "### 1.2 Create time-series features & analysis\n", - "\n", - "The following code cell will extract the review logs from your Anki collection and preprocess them into a training set, which is saved in `revlog_history.tsv`.\n", - "\n", - "The time-series features are important in optimizing the model's parameters. For more details, please see my paper: https://www.maimemo.com/paper/\n", - "\n", - "Then it generates a concise analysis of your review logs; a short worked example follows the list below.\n", - "\n", - "- The `r_history` is the history of ratings on each review. `3,3,3,1` means that you pressed `Good, Good, Good, Again`. It only contains the first rating for each card on a given review date, i.e., if you press `Again` during review and then `Good` in the relearning step 10 minutes later, only `Again` is recorded.\n", - "- The `avg_interval` is the actual average interval after you rate your cards as in `r_history`. It could be longer than the interval given by Anki's built-in scheduler because you reviewed some overdue cards.\n", - "- The `avg_retention` is the average retention after rating your cards as in `r_history`. `Again` counts as failed recall, and `Hard, Good and Easy` count as successful recall. Retention is the percentage of successful recalls.\n", - "- The `stability` is the estimated memory state variable, which is an approximate interval that leads to 90% retention.\n", - "- The `factor` is `stability / previous stability`.\n", - "- The `group_cnt` is the number of review logs that have the same `r_history`."
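, - "\n", - "As a minimal illustration of these definitions (the numbers here are made up), `retention` and `factor` could be computed like this:\n", - "\n", - "```python\n", - "ratings = [3, 3, 3, 1]  # Good, Good, Good, Again\n", - "recalled = [1 if r > 1 else 0 for r in ratings]  # Again fails; Hard/Good/Easy succeed\n", - "retention = sum(recalled) / len(recalled)  # 0.75\n", - "factor = 12.3 / 5.6  # stability / previous stability, with hypothetical stabilities\n", - "```"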
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "J2IIaY3PDaaG", - "outputId": "607916c9-da95-48dd-fdab-6bd83fbbbb40" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "revlog.csv saved.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c9bb754c7ac441068199f88788b35a74", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/30711 [00:00 0)].copy()\n", - "df['create_date'] = pd.to_datetime(df['cid'] // 1000, unit='s')\n", - "df['create_date'] = df['create_date'].dt.tz_localize(\n", - " 'UTC').dt.tz_convert(timezone)\n", - "df['review_date'] = pd.to_datetime(df['id'] // 1000, unit='s')\n", - "df['review_date'] = df['review_date'].dt.tz_localize(\n", - " 'UTC').dt.tz_convert(timezone)\n", - "df.drop(df[df['review_date'].dt.year < 2006].index, inplace=True)\n", - "df.sort_values(by=['cid', 'id'], inplace=True, ignore_index=True)\n", - "type_sequence = np.array(df['type'])\n", - "time_sequence = np.array(df['time'])\n", - "df.to_csv(\"revlog.csv\", index=False)\n", - "print(\"revlog.csv saved.\")\n", - "df = df[(df['type'] == 0) | (df['type'] == 1)].copy()\n", - "df['real_days'] = df['review_date'] - timedelta(hours=next_day_starts_at)\n", - "df['real_days'] = pd.DatetimeIndex(df['real_days'].dt.floor('D')).to_julian_date()\n", - "df.drop_duplicates(['cid', 'real_days'], keep='first', inplace=True)\n", - "df['delta_t'] = df.real_days.diff()\n", - "df.dropna(inplace=True)\n", - "df['delta_t'] = df['delta_t'].astype(dtype=int)\n", - "df['i'] = 1\n", - "df['r_history'] = \"\"\n", - "df['t_history'] = \"\"\n", - "col_idx = {key: i for i, key in enumerate(df.columns)}\n", - "\n", - "\n", - "# code from https://github.com/L-M-Sherlock/anki_revlog_analysis/blob/main/revlog_analysis.py\n", - "def get_feature(x):\n", - " for idx, log in enumerate(x.itertuples()):\n", - " if idx == 0:\n", - " x.iloc[idx, col_idx['delta_t']] = 0\n", - " if idx == x.shape[0] - 1:\n", - " break\n", - " x.iloc[idx + 1, col_idx['i']] = x.iloc[idx, col_idx['i']] + 1\n", - " x.iloc[idx + 1, col_idx['t_history']] = f\"{x.iloc[idx, col_idx['t_history']]},{x.iloc[idx, col_idx['delta_t']]}\"\n", - " x.iloc[idx + 1, col_idx['r_history']] = f\"{x.iloc[idx, col_idx['r_history']]},{x.iloc[idx, col_idx['r']]}\"\n", - " return x\n", - "\n", - "tqdm.notebook.tqdm.pandas()\n", - "df = df.groupby('cid', as_index=False).progress_apply(get_feature)\n", - "df = df[df['id'] >= time.mktime(datetime.strptime(revlog_start_date, \"%Y-%m-%d\").timetuple()) * 1000]\n", - "df[\"t_history\"] = df[\"t_history\"].map(lambda x: x[1:] if len(x) > 1 else x)\n", - "df[\"r_history\"] = df[\"r_history\"].map(lambda x: x[1:] if len(x) > 1 else x)\n", - "df.to_csv('revlog_history.tsv', sep=\"\\t\", index=False)\n", - "print(\"Trainset saved.\")\n", - "\n", - "def cal_retention(group: pd.DataFrame) -> pd.DataFrame:\n", - " group['retention'] = round(group['r'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x]).mean(), 4)\n", - " group['total_cnt'] = group.shape[0]\n", - " return group\n", - "\n", - "df = df.groupby(by=['r_history', 'delta_t']).progress_apply(cal_retention)\n", - "print(\"Retention calculated.\")\n", - "df = df.drop(columns=['id', 'cid', 'usn', 'ivl', 'last_lvl', 'factor', 'time', 'type', 'create_date', 'review_date', 'real_days', 'r', 't_history'])\n", - "df.drop_duplicates(inplace=True)\n", - "df = df[(df['retention'] < 1) & (df['retention'] > 0)]\n", - "\n", - 
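"# Assuming the forgetting curve R(t) = 0.9 ** (t / S), ln(R) is linear in t, so the stability S\n", - "# below is estimated by a count-weighted least-squares fit of ln(retention) against delta_t.\n", -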
"def cal_stability(group: pd.DataFrame) -> pd.DataFrame:\n", - " if group['i'].values[0] > 1:\n", - " r_ivl_cnt = sum(group['delta_t'] * group['retention'].map(np.log) * pow(group['total_cnt'], 2))\n", - " ivl_ivl_cnt = sum(group['delta_t'].map(lambda x: x ** 2) * pow(group['total_cnt'], 2))\n", - " group['stability'] = round(np.log(0.9) / (r_ivl_cnt / ivl_ivl_cnt), 1)\n", - " else:\n", - " group['stability'] = 0.0\n", - " group['group_cnt'] = sum(group['total_cnt'])\n", - " group['avg_retention'] = round(sum(group['retention'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 3)\n", - " group['avg_interval'] = round(sum(group['delta_t'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 1)\n", - " del group['total_cnt']\n", - " del group['retention']\n", - " del group['delta_t']\n", - " return group\n", - "\n", - "df = df.groupby(by=['r_history']).progress_apply(cal_stability)\n", - "print(\"Stability calculated.\")\n", - "df.reset_index(drop = True, inplace = True)\n", - "df.drop_duplicates(inplace=True)\n", - "df.sort_values(by=['r_history'], inplace=True, ignore_index=True)\n", - "\n", - "if df.shape[0] > 0:\n", - " for idx in tqdm.notebook.tqdm(df.index):\n", - " item = df.loc[idx]\n", - " index = df[(df['i'] == item['i'] + 1) & (df['r_history'].str.startswith(item['r_history']))].index\n", - " df.loc[index, 'last_stability'] = item['stability']\n", - " df['factor'] = round(df['stability'] / df['last_stability'], 2)\n", - " df = df[(df['i'] >= 2) & (df['group_cnt'] >= 100)]\n", - " df['last_recall'] = df['r_history'].map(lambda x: x[-1])\n", - " df = df[df.groupby(['i', 'r_history'])['group_cnt'].transform(max) == df['group_cnt']]\n", - " df.to_csv('./stability_for_analysis.tsv', sep='\\t', index=None)\n", - " print(\"1:again, 2:hard, 3:good, 4:easy\\n\")\n", - " print(df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(index=False))\n", - " print(\"Analysis saved!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "k_SgzC-auWmu" - }, - "source": [ - "## 2 Optimize parameter" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WrfBJjqCHEwJ" - }, - "source": [ - "### 2.1 Define the model\n", - "\n", - "FSRS is a time-series model for predicting memory states." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "tdYp3GMLhTYm" - }, - "outputs": [], - "source": [ - "init_w = [1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.12, 0.8, 2, -0.2, 0.2, 1]\n", - "'''\n", - "w[0]: initial_stability_for_again_answer\n", - "w[1]: initial_stability_step_per_rating\n", - "w[2]: initial_difficulty_for_good_answer\n", - "w[3]: initial_difficulty_step_per_rating\n", - "w[4]: next_difficulty_step_per_rating\n", - "w[5]: next_difficulty_reversion_to_mean_speed (used to avoid ease hell)\n", - "w[6]: next_stability_factor_after_success\n", - "w[7]: next_stability_stabilization_decay_after_success\n", - "w[8]: next_stability_retrievability_gain_after_success\n", - "w[9]: next_stability_factor_after_failure\n", - "w[10]: next_stability_difficulty_decay_after_failure\n", - "w[11]: next_stability_stability_gain_after_failure\n", - "w[12]: next_stability_retrievability_gain_after_failure\n", - "For more details about the parameters, please see: \n", - "https://github.com/open-spaced-repetition/fsrs4anki/wiki/Free-Spaced-Repetition-Scheduler\n", - "'''\n", - "\n", - "\n", - "class FSRS(nn.Module):\n", - " def __init__(self, w):\n", - " super(FSRS, self).__init__()\n", - " self.w = nn.Parameter(torch.FloatTensor(w))\n", - " self.zero = torch.FloatTensor([0.0])\n", - "\n", - " def forward(self, x, s, d):\n", - " '''\n", - " :param x: [review interval, review response]\n", - " :param s: stability\n", - " :param d: difficulty\n", - " :return:\n", - " '''\n", - " if torch.equal(s, self.zero):\n", - " # first learn, init memory states\n", - " new_s = self.w[0] + self.w[1] * (x[1] - 1)\n", - " new_d = self.w[2] + self.w[3] * (x[1] - 3)\n", - " new_d = new_d.clamp(1, 10)\n", - " else:\n", - " r = torch.exp(np.log(0.9) * x[0] / s)\n", - " new_d = d + self.w[4] * (x[1] - 3)\n", - " new_d = self.mean_reversion(self.w[2], new_d)\n", - " new_d = new_d.clamp(1, 10)\n", - " # recall\n", - " if x[1] > 1:\n", - " new_s = s * (1 + torch.exp(self.w[6]) *\n", - " (11 - new_d) *\n", - " torch.pow(s, self.w[7]) *\n", - " (torch.exp((1 - r) * self.w[8]) - 1))\n", - " # forget\n", - " else:\n", - " new_s = self.w[9] * torch.pow(new_d, self.w[10]) * torch.pow(\n", - " s, self.w[11]) * torch.exp((1 - r) * self.w[12])\n", - " return new_s, new_d\n", - "\n", - " def loss(self, s, t, r):\n", - " # negative log-likelihood of the observed recall r, given R = 0.9 ** (t / s)\n", - " return - (r * np.log(0.9) * t / s + (1 - r) * torch.log(1 - torch.exp(np.log(0.9) * t / s)))\n", - "\n", - " def mean_reversion(self, init, current):\n", - " return self.w[5] * init + (1-self.w[5]) * current\n", - "\n", - "\n", - "class WeightClipper(object):\n", - " def __init__(self, frequency=1):\n", - " self.frequency = frequency\n", - "\n", - " def __call__(self, module):\n", - " if hasattr(module, 'w'):\n", - " w = module.w.data\n", - " w[0] = w[0].clamp(0.1, 10)\n", - " w[1] = w[1].clamp(0.1, 5)\n", - " w[2] = w[2].clamp(1, 10)\n", - " w[3] = w[3].clamp(-5, -0.1)\n", - " w[4] = w[4].clamp(-5, -0.1)\n", - " w[5] = w[5].clamp(0, 0.5)\n", - " w[6] = w[6].clamp(0, 2)\n", - " w[7] = w[7].clamp(-0.2, -0.01)\n", - " w[8] = w[8].clamp(0.01, 1.5)\n", - " w[9] = w[9].clamp(0.5, 5)\n", - " w[10] = w[10].clamp(-2, -0.01)\n", - " w[11] = w[11].clamp(0.01, 0.9)\n", - " w[12] = w[12].clamp(0.01, 2)\n", - " module.w.data = w\n", - "\n", - "def lineToTensor(line):\n", - " ivl = line[0].split(',')\n", - " response = line[1].split(',')\n", - " tensor = torch.zeros(len(response), 2)\n", - " for li, response in enumerate(response):\n", - " tensor[li][0] = int(ivl[li])\n", - " tensor[li][1] =
int(response)\n", - " return tensor\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8E1dYfgQLZAC" - }, - "source": [ - "### 2.2 Train the model\n", - "\n", - "The `revlog_history.tsv` file generated earlier is used to train the FSRS model." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Jht0gneShowU", - "outputId": "aaa72b79-b454-483b-d746-df1a353b2c8f" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "749aeda9cb624986ae2872489bcc6762", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/225934 [00:00 1) & (dataset['delta_t'] > 0) & (dataset['t_history'].str.count(',0') == 0)]\n", - "dataset['tensor'] = dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]), axis=1)\n", - "print(\"Tensorized!\")\n", - "\n", - "pre_train_set = dataset[dataset['i'] == 2]\n", - "# pretrain\n", - "epoch_len = len(pre_train_set)\n", - "n_epoch = 1\n", - "pbar = tqdm.notebook.tqdm(desc=\"pre-train\", colour=\"red\", total=epoch_len*n_epoch)\n", - "\n", - "for k in range(n_epoch):\n", - " for i, (_, row) in enumerate(shuffle(pre_train_set, random_state=2022 + k).iterrows()):\n", - " model.train()\n", - " optimizer.zero_grad()\n", - " output_t = [(model.zero, model.zero)]\n", - " for input_t in row['tensor']:\n", - " output_t.append(model(input_t, *output_t[-1]))\n", - " loss = model.loss(output_t[-1][0], row['delta_t'],\n", - " {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])\n", - " if np.isnan(loss.data.item()):\n", - " # Exception Case\n", - " print(row, output_t)\n", - " raise Exception('error case')\n", - " loss.backward()\n", - " optimizer.step()\n", - " model.apply(clipper)\n", - " pbar.update()\n", - "pbar.close()\n", - "for name, param in model.named_parameters():\n", - " print(f\"{name}: {list(map(lambda x: round(float(x), 4),param))}\")\n", - "\n", - "train_set = dataset[dataset['i'] > 2]\n", - "epoch_len = len(train_set)\n", - "n_epoch = 1\n", - "print_len = max(epoch_len*n_epoch // 10, 1)\n", - "pbar = tqdm.notebook.tqdm(desc=\"train\", colour=\"red\", total=epoch_len*n_epoch)\n", - "\n", - "for k in range(n_epoch):\n", - " for i, (_, row) in enumerate(shuffle(train_set, random_state=2022 + k).iterrows()):\n", - " model.train()\n", - " optimizer.zero_grad()\n", - " output_t = [(model.zero, model.zero)]\n", - " for input_t in row['tensor']:\n", - " output_t.append(model(input_t, *output_t[-1]))\n", - " loss = model.loss(output_t[-1][0], row['delta_t'],\n", - " {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])\n", - " if np.isnan(loss.data.item()):\n", - " # Exception Case\n", - " print(row, output_t)\n", - " raise Exception('error case')\n", - " loss.backward()\n", - " for param in model.parameters():\n", - " # keep the pretrained initial-stability parameters w[0] and w[1] fixed\n", - " param.grad[:2] = torch.zeros(2)\n", - " optimizer.step()\n", - " model.apply(clipper)\n", - " pbar.update()\n", - "\n", - " if (k * epoch_len + i) % print_len == 0:\n", - " print(f\"iteration: {k * epoch_len + i + 1}\")\n", - " for name, param in model.named_parameters():\n", - " print(f\"{name}: {list(map(lambda x: round(float(x), 4),param))}\")\n", - "pbar.close()\n", - "\n", - "w = list(map(lambda x: round(float(x), 4), dict(model.named_parameters())['w'].data))\n", - "\n", - "print(\"\\nTraining finished!\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BZ4S2l7BWfzr" - }, - "source": [ - "### 2.3 Result\n", - "\n", - "Copy your optimal FSRS parameters from
the output of the next code cell after it runs." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "NTnPSDA2QpUu", - "outputId": "49f487b9-69a7-4e96-b35a-7e027f478fbd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "var w = [1.014, 2.2933, 4.9444, -1.1646, -0.9942, 0.0227, 1.3911, -0.0498, 0.7376, 1.7016, -0.4742, 0.602, 0.9946];\n" - ] - } - ], - "source": [ - "print(f\"var w = {w};\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.4 Preview" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I_zsoDyTaTrT" - }, - "source": [ - "You can see the memory states and intervals generated by FSRS as if you pressed `Good` at each review, on the due date scheduled by FSRS." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "iws4rtP1WKBT", - "outputId": "890d0287-1a17-4c59-fbbf-ee54d79cd383" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1:again, 2:hard, 3:good, 4:easy\n", - "\n", - "first rating: 1\n", - "rating history: 1,3,3,3,3,3,3,3,3,3,3\n", - "interval history: 0,1,2,4,9,19,39,79,159,315,618\n", - "difficulty history: 0,7.3,7.2,7.2,7.1,7.1,7.0,7.0,6.9,6.9,6.8\n", - "\n", - "first rating: 2\n", - "rating history: 2,3,3,3,3,3,3,3,3,3,3\n", - "interval history: 0,3,8,19,44,100,222,485,1041,2193,4543\n", - "difficulty history: 0,6.1,6.1,6.1,6.0,6.0,6.0,6.0,5.9,5.9,5.9\n", - "\n", - "first rating: 3\n", - "rating history: 3,3,3,3,3,3,3,3,3,3,3\n", - "interval history: 0,6,16,42,107,265,639,1502,3446,7725,16941\n", - "difficulty history: 0,4.9,4.9,4.9,4.9,4.9,4.9,4.9,4.9,4.9,4.9\n", - "\n", - "first rating: 4\n", - "rating history: 4,3,3,3,3,3,3,3,3,3,3\n", - "interval history: 0,8,24,69,192,515,1339,3374,8256,19643,45511\n", - "difficulty history: 0,3.8,3.8,3.8,3.9,3.9,3.9,3.9,4.0,4.0,4.0\n", - "\n" - ] - } - ], - "source": [ - "requestRetention = 0.9 # recommended setting: 0.8 ~ 0.9\n", - "\n", - "\n", - "class Collection:\n", - " def __init__(self, w):\n", - " self.model = FSRS(w)\n", - "\n", - " def states(self, t_history, r_history):\n", - " with torch.no_grad():\n", - " line_tensor = lineToTensor(list(zip([t_history], [r_history]))[0])\n", - " output_t = [(self.model.zero, self.model.zero)]\n", - " for input_t in line_tensor:\n", - " output_t.append(self.model(input_t, *output_t[-1]))\n", - " return output_t[-1]\n", - "\n", - "\n", - "my_collection = Collection(w)\n", - "print(\"1:again, 2:hard, 3:good, 4:easy\\n\")\n", - "for first_rating in (1,2,3,4):\n", - " print(f'first rating: {first_rating}')\n", - " t_history = \"0\"\n", - " d_history = \"0\"\n", - " r_history = f\"{first_rating}\" # the first rating of the new card\n", - " # print(\"stability, difficulty, lapses\")\n", - " for i in range(10):\n", - " states = my_collection.states(t_history, r_history)\n", - " # print('{0:9.2f} {1:11.2f} {2:7.0f}'.format(\n", - " # *list(map(lambda x: round(float(x), 4), states))))\n", - " # interval that hits requestRetention: t = S * ln(R) / ln(0.9)\n", - " next_t = max(round(float(np.log(requestRetention)/np.log(0.9) * states[0])), 1)\n", - " difficulty = round(float(states[1]), 1)\n", - " t_history += f',{int(next_t)}'\n", - " d_history += f',{difficulty}'\n", - " r_history += f\",3\"\n", - " print(f\"rating history: {r_history}\")\n", - " print(f\"interval history: {t_history}\")\n", - " print(f\"difficulty history: {d_history}\")\n", -
print('')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can change `test_rating_sequence` to see the intervals scheduled under different rating sequences." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor(5.6006), tensor(4.9444))\n", - "(tensor(15.8482), tensor(4.9444))\n", - "(tensor(41.8218), tensor(4.9444))\n", - "(tensor(106.7919), tensor(4.9444))\n", - "(tensor(264.7672), tensor(4.9444))\n", - "(tensor(21.6391), tensor(6.8877))\n", - "(tensor(4.2755), tensor(8.7868))\n", - "(tensor(6.9117), tensor(8.6996))\n", - "(tensor(11.5771), tensor(8.6143))\n", - "(tensor(19.6411), tensor(8.5310))\n", - "(tensor(33.1676), tensor(8.4496))\n", - "(tensor(55.5986), tensor(8.3701))\n", - "rating history: 3,3,3,3,3,1,1,3,3,3,3,3\n", - "interval history: 0,6,16,42,107,265,22,4,7,12,20,33,56\n", - "difficulty history: 0,4.9,4.9,4.9,4.9,4.9,6.9,8.8,8.7,8.6,8.5,8.4,8.4\n" - ] - } - ], - "source": [ - "test_rating_sequence = \"3,3,3,3,3,1,1,3,3,3,3,3\"\n", - "requestRetention = 0.9 # recommended setting: 0.8 ~ 0.9\n", - "easyBonus = 1.3\n", - "hardInterval = 1.2\n", - "\n", - "t_history = \"0\"\n", - "d_history = \"0\"\n", - "for i in range(len(test_rating_sequence.split(','))):\n", - " # each rating is a single character at the even positions of the comma-separated string\n", - " rating = test_rating_sequence[2*i]\n", - " last_t = int(t_history.split(',')[-1])\n", - " r_history = test_rating_sequence[:2*i+1]\n", - " states = my_collection.states(t_history, r_history)\n", - " print(states)\n", - " next_t = max(1,round(float(np.log(requestRetention)/np.log(0.9) * states[0])))\n", - " if rating == '4':\n", - " next_t = round(next_t * easyBonus)\n", - " elif rating == '2':\n", - " next_t = round(last_t * hardInterval)\n", - " t_history += f',{int(next_t)}'\n", - " difficulty = round(float(states[1]), 1)\n", - " d_history += f',{difficulty}'\n", - "print(f\"rating history: {test_rating_sequence}\")\n", - "print(f\"interval history: {t_history}\")\n", - "print(f\"difficulty history: {d_history}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.5 Predict memory states and distribution of difficulty\n", - "\n", - "Predict memory states for each review and save them in `prediction.tsv`.\n", - "\n", - "Meanwhile, it counts the distribution of difficulty."
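, - "\n", - "For reference: given a predicted stability $S$ and a target retention $R$, the scheduled interval follows from inverting the forgetting curve, $t = S \\cdot \\ln(R) / \\ln(0.9)$. A minimal sketch of such a helper (not part of the original code; the name `next_interval` is made up):\n", - "\n", - "```python\n", - "import numpy as np\n", - "\n", - "def next_interval(stability: float, request_retention: float = 0.9) -> int:\n", - "    # invert R = 0.9 ** (t / S); with request_retention = 0.9, t equals S\n", - "    return max(1, round(stability * np.log(request_retention) / np.log(0.9)))\n", - "```"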
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7712ed30c62643f4aa834c840458158c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/119670 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "base = 1.01\n", - "index_len = 793\n", - "index_offset = 200\n", - "d_range = 10\n", - "d_offset = 1\n", - "r_time = 8\n", - "f_time = 25\n", - "max_time = 200000\n", - "\n", - "type_block = dict()\n", - "type_count = dict()\n", - "type_time = dict()\n", - "last_t = type_sequence[0]\n", - "type_block[last_t] = 1\n", - "type_count[last_t] = 1\n", - "type_time[last_t] = time_sequence[0]\n", - "for i,t in enumerate(type_sequence[1:]):\n", - " type_count[t] = type_count.setdefault(t, 0) + 1\n", - " type_time[t] = type_time.setdefault(t, 0) + time_sequence[i]\n", - " if t != last_t:\n", - " type_block[t] = type_block.setdefault(t, 0) + 1\n", - " last_t = t\n", - "\n", - "r_time = round(type_time[1]/type_count[1]/1000, 1)\n", - "\n", - "if 2 in type_count and 2 in type_block:\n", - " f_time = round(type_time[2]/type_block[2]/1000 + r_time, 1)\n", - "\n", - "print(f\"average time for failed cards: {f_time}s\")\n", - "print(f\"average time for recalled cards: {r_time}s\")\n", - "\n", - "def stability2index(stability):\n", - " return int(round(np.log(stability) / np.log(base)) + index_offset)\n", - "\n", - "def init_stability(d):\n", - " return max(((d - w[2]) / w[3] + 2) * w[1] + w[0], np.power(base, -index_offset))\n", - "\n", - "def cal_next_recall_stability(s, r, d, response):\n", - " if response == 1:\n", - " return s * (1 + np.exp(w[6]) * (11 - d) * np.power(s, w[7]) * (np.exp((1 - r) * w[8]) - 1))\n", - " else:\n", - " return w[9] * np.power(d, w[10]) * np.power(s, w[11]) * np.exp((1 - r) * w[12])\n", - "\n", - "\n", - "stability_list = np.array([np.power(base, i - index_offset) for i in range(index_len)])\n", - "print(f\"terminal stability: {stability_list.max(): .2f}\")\n", - "df = pd.DataFrame(columns=[\"retention\", \"difficulty\", \"time\"])\n", - "\n", - "for percentage in tqdm.notebook.tqdm(range(96, 66, -2)):\n", - " recall = percentage / 100\n", - " time_list = np.zeros((d_range, index_len))\n", - " time_list[:,:-1] = max_time\n", - " for d in range(d_range, 0, -1):\n", - " s0 = init_stability(d)\n", - " s0_index = stability2index(s0)\n", - " diff = max_time\n", - " while diff > 0.1:\n", - " s0_time = time_list[d - 1][s0_index]\n", - " for s_index in range(index_len - 2, -1, -1):\n", - " stability = stability_list[s_index];\n", - " interval = max(1, round(stability * np.log(recall) / np.log(0.9)))\n", - " p_recall = np.power(0.9, interval / stability)\n", - " recall_s = cal_next_recall_stability(stability, p_recall, d, 1)\n", - " forget_d = min(d + d_offset, 10)\n", - " forget_s = cal_next_recall_stability(stability, p_recall, forget_d, 0)\n", - " recall_s_index = min(stability2index(recall_s), index_len - 1)\n", - " forget_s_index = min(max(stability2index(forget_s), 0), index_len - 1)\n", - " recall_time = time_list[d - 1][recall_s_index] + r_time\n", - " forget_time = time_list[forget_d - 1][forget_s_index] + f_time\n", - " exp_time = p_recall * recall_time + (1.0 - p_recall) * forget_time\n", - " if exp_time < time_list[d - 1][s_index]:\n", - " time_list[d - 1][s_index] = exp_time\n", - " diff = s0_time - time_list[d - 1][s0_index]\n", - " df.loc[0 if pd.isnull(df.index.max()) else df.index.max() + 1] 
= [recall, d, s0_time]\n", - "\n", - "df.sort_values(by=[\"difficulty\", \"retention\"], inplace=True)\n", - "df.to_csv(\"./expected_time.csv\", index=False)\n", - "print(\"expected_time.csv saved.\")\n", - "\n", - "optimal_retention_list = np.zeros(10)\n", - "for d in range(1, d_range+1):\n", - " retention = df[df[\"difficulty\"] == d][\"retention\"]\n", - " time = df[df[\"difficulty\"] == d][\"time\"]\n", - " optimal_retention = retention.iat[time.argmin()]\n", - " optimal_retention_list[d-1] = optimal_retention\n", - " plt.plot(retention, time, label=f\"d={d}, r={optimal_retention}\")\n", - "print(f\"\\n-----suggested retention (experimental): {np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}-----\")\n", - "plt.ylabel(\"expected time (seconds)\")\n", - "plt.xlabel(\"retention\")\n", - "plt.legend()\n", - "plt.grid()\n", - "plt.semilogy()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4 Evaluate the model\n", - "\n", - "Evaluate the model with the log loss. It compares the log loss of the initial model with that of the trained model." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "63e8453d80124699b1f40395051577d7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/225934 [00:00