diff --git "a/fsrs4anki_optimizer.ipynb" "b/fsrs4anki_optimizer.ipynb" --- "a/fsrs4anki_optimizer.ipynb" +++ "b/fsrs4anki_optimizer.ipynb" @@ -1,1180 +1,1227 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# FSRS4Anki v3.10.1 Optimizer" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "id": "lurCmW0Jqz3s" - }, - "source": [ - "[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v3.10.1/fsrs4anki_optimizer.ipynb)\n", - "\n", - "↑ Click the above button to open the optimizer on Google Colab.\n", - "\n", - "> If you can't see the button and are located in the Chinese Mainland, please use a proxy or VPN." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wG7bBfGJFbMr" - }, - "source": [ - "Upload your **Anki Deck Package (.apkg)** file or **Anki Collection Package (.colpkg)** file on the `Left sidebar -> Files`, drag and drop your file in the current directory (not the `sample_data` directory). \n", - "\n", - "No need to include media. Need to include scheduling information. \n", - "\n", - "> If you use the latest version of Anki, please check the box `Support older Anki versions (slower/larger files)` when you export.\n", - "\n", - "You can export it via `File -> Export...` or `Ctrl + E` in the main window of Anki.\n", - "\n", - "Then replace the `filename` with yours in the next code cell. And set the `timezone` and `next_day_starts_at` which can be found in your preferences of Anki.\n", - "\n", - "After that, just run all (`Runtime -> Run all` or `Ctrl + F9`) and wait for minutes. You can see the optimal parameters in section **3 Result**. Copy them, replace the parameters in `fsrs4anki_scheduler.js`, and paste them into the custom scheduling of your deck options (require Anki version >= 2.1.55).\n", - "\n", - "**NOTE**: The default output is generated from my review logs. If you find the output is the same as mine, maybe your notebook hasn't run there.\n", - "\n", - "**Contribute to SRS Research**: If you want to share your data with me, please fill this form: https://forms.gle/KaojsBbhMCytaA7h8" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "iqP70_-3EUhi" - }, - "outputs": [], - "source": [ - "# Here are some settings that you need to replace before running this optimizer.\n", - "\n", - "filename = \"collection-2022-09-18@13-21-58.colpkg\"\n", - "# If you upload deck file, replace it with your deck filename. E.g., ALL__Learning.apkg\n", - "# If you upload collection file, replace it with your colpgk filename. E.g., collection-2022-09-18@13-21-58.colpkg\n", - "\n", - "# Replace it with your timezone. I'm in China, so I use Asia/Shanghai.\n", - "# You can find your timezone here: https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568\n", - "timezone = 'Asia/Shanghai'\n", - "\n", - "# Replace it with your Anki's setting in Preferences -> Scheduling.\n", - "next_day_starts_at = 4\n", - "\n", - "# Replace it if you don't want the optimizer to use the review logs before a specific date.\n", - "revlog_start_date = \"2006-10-05\"\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bLFVNmG2qd06" - }, - "source": [ - "## 1 Build dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EkzFeKawqgbs" - }, - "source": [ - "### 1.1 Extract Anki collection & deck file" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "KD2js_wEr_Bs", - "outputId": "42653d9e-316e-40bc-bd1d-f3a0e2b246c7" - }, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extract successfully!\n" - ] - } - ], - "source": [ - "import zipfile\n", - "import sqlite3\n", - "import time\n", - "import tqdm\n", - "import pandas as pd\n", - "import numpy as np\n", - "import os\n", - "from datetime import timedelta, datetime\n", - "import matplotlib.pyplot as plt\n", - "import math\n", - "import sys\n", - "import torch\n", - "from torch import nn\n", - "from sklearn.utils import shuffle\n", - "# Extract the collection file or deck file to get the .anki21 database.\n", - "with zipfile.ZipFile(f'./{filename}', 'r') as zip_ref:\n", - " zip_ref.extractall('./')\n", - " print(\"Extract successfully!\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dKpy4VfqGmaL" - }, - "source": [ - "### 1.2 Create time-series feature & analysis\n", - "\n", - "The following code cell will extract the review logs from your Anki collection and preprocess them to a trainset which is saved in `revlog_history.tsv`.\n", - "\n", - "The time-series features are important in optimizing the model's parameters. For more detail, please see my paper: https://www.maimemo.com/paper/\n", - "\n", - "Then it will generate a concise analysis for your review logs. \n", - "\n", - "- The `r_history` is the history of ratings on each review. `3,3,3,1` means that you press `Good, Good, Good, Again`. It only contains the first rating for each card on the review date, i.e., when you press `Again` in review and `Good` in relearning steps 10min later, only `Again` will be recorded.\n", - "- The `avg_interval` is the actual average interval after you rate your cards as the `r_history`. It could be longer than the interval given by Anki's built-in scheduler because you reviewed some overdue cards.\n", - "- The `avg_retention` is the average retention after you press as the `r_history`. `Again` counts as failed recall, and `Hard, Good and Easy` count as successful recall. Retention is the percentage of your successful recall.\n", - "- The `stability` is the estimated memory state variable, which is an approximate interval that leads to 90% retention.\n", - "- The `factor` is `stability / previous stability`.\n", - "- The `group_cnt` is the number of review logs that have the same `r_history`." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# FSRS4Anki v3.13.0 Optimizer" + ] }, - "id": "J2IIaY3PDaaG", - "outputId": "607916c9-da95-48dd-fdab-6bd83fbbbb40" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "revlog.csv saved.\n" - ] + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "lurCmW0Jqz3s" + }, + "source": [ + "[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v3.13.0/fsrs4anki_optimizer.ipynb)\n", + "\n", + "↑ Click the above button to open the optimizer on Google Colab.\n", + "\n", + "> If you can't see the button and are located in the Chinese Mainland, please use a proxy or VPN." + ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fec0c445154d4182bbff35e17f98e0ef", - "version_major": 2, - "version_minor": 0 + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "wG7bBfGJFbMr" }, - "text/plain": [ - " 0%| | 0/30711 [00:00 Files`, drag and drop your file in the current directory (not the `sample_data` directory). \n", + "\n", + "No need to include media. Need to include scheduling information. \n", + "\n", + "> If you use the latest version of Anki, please check the box `Support older Anki versions (slower/larger files)` when you export.\n", + "\n", + "You can export it via `File -> Export...` or `Ctrl + E` in the main window of Anki.\n", + "\n", + "Then replace the `filename` with yours in the next code cell. And set the `timezone` and `next_day_starts_at` which can be found in your preferences of Anki.\n", + "\n", + "After that, just run all (`Runtime -> Run all` or `Ctrl + F9`) and wait for minutes. You can see the optimal parameters in section **2.3 Result**. Copy them, replace the parameters in `fsrs4anki_scheduler.js`, and paste them into the custom scheduling of your deck options (require Anki version >= 2.1.55).\n", + "\n", + "**NOTE**: The default output is generated from my review logs. If you find the output is the same as mine, maybe your notebook hasn't run there.\n", + "\n", + "**Contribute to SRS Research**: If you want to share your data with me, please fill this form: https://forms.gle/KaojsBbhMCytaA7h8" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Trainset saved.\n" - ] + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "iqP70_-3EUhi" + }, + "outputs": [], + "source": [ + "# Here are some settings that you need to replace before running this optimizer.\n", + "\n", + "filename = \"collection-2022-09-18@13-21-58.colpkg\"\n", + "# If you upload deck file, replace it with your deck filename. E.g., ALL__Learning.apkg\n", + "# If you upload collection file, replace it with your colpgk filename. E.g., collection-2022-09-18@13-21-58.colpkg\n", + "\n", + "# Replace it with your timezone. I'm in China, so I use Asia/Shanghai.\n", + "# You can find your timezone here: https://gist.github.com/heyalexej/8bf688fd67d7199be4a1682b3eec7568\n", + "timezone = 'Asia/Shanghai'\n", + "\n", + "# Replace it with your Anki's setting in Preferences -> Scheduling.\n", + "next_day_starts_at = 4\n", + "\n", + "# Replace it if you don't want the optimizer to use the review logs before a specific date.\n", + "revlog_start_date = \"2006-10-05\"\n" + ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8c0036ef716b420983b5569158548ae4", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "bLFVNmG2qd06" }, - "text/plain": [ - " 0%| | 0/96660 [00:00 0) &\n", + " (df['id'] >= time.mktime(datetime.strptime(revlog_start_date, \"%Y-%m-%d\").timetuple()) * 1000)].copy()\n", + "df['create_date'] = pd.to_datetime(df['cid'] // 1000, unit='s')\n", + "df['create_date'] = df['create_date'].dt.tz_localize(\n", + " 'UTC').dt.tz_convert(timezone)\n", + "df['review_date'] = pd.to_datetime(df['id'] // 1000, unit='s')\n", + "df['review_date'] = df['review_date'].dt.tz_localize(\n", + " 'UTC').dt.tz_convert(timezone)\n", + "df.drop(df[df['review_date'].dt.year < 2006].index, inplace=True)\n", + "df.sort_values(by=['cid', 'id'], inplace=True, ignore_index=True)\n", + "type_sequence = np.array(df['type'])\n", + "time_sequence = np.array(df['time'])\n", + "df.to_csv(\"revlog.csv\", index=False)\n", + "print(\"revlog.csv saved.\")\n", + "df = df[(df['type'] == 0) | (df['type'] == 1)].copy()\n", + "df['real_days'] = df['review_date'] - timedelta(hours=next_day_starts_at)\n", + "df['real_days'] = pd.DatetimeIndex(df['real_days'].dt.floor('D')).to_julian_date()\n", + "df.drop_duplicates(['cid', 'real_days'], keep='first', inplace=True)\n", + "df['delta_t'] = df.real_days.diff()\n", + "df.dropna(inplace=True)\n", + "df['delta_t'] = df['delta_t'].astype(dtype=int)\n", + "df['i'] = 1\n", + "df['r_history'] = \"\"\n", + "df['t_history'] = \"\"\n", + "col_idx = {key: i for i, key in enumerate(df.columns)}\n", + "\n", + "\n", + "# code from https://github.com/L-M-Sherlock/anki_revlog_analysis/blob/main/revlog_analysis.py\n", + "def get_feature(x):\n", + " for idx, log in enumerate(x.itertuples()):\n", + " if idx == 0:\n", + " x.iloc[idx, col_idx['delta_t']] = 0\n", + " if idx == x.shape[0] - 1:\n", + " break\n", + " x.iloc[idx + 1, col_idx['i']] = x.iloc[idx, col_idx['i']] + 1\n", + " x.iloc[idx + 1, col_idx['t_history']] = f\"{x.iloc[idx, col_idx['t_history']]},{x.iloc[idx, col_idx['delta_t']]}\"\n", + " x.iloc[idx + 1, col_idx['r_history']] = f\"{x.iloc[idx, col_idx['r_history']]},{x.iloc[idx, col_idx['r']]}\"\n", + " return x\n", + "\n", + "tqdm.notebook.tqdm.pandas()\n", + "df = df.groupby('cid', as_index=False).progress_apply(get_feature)\n", + "df[\"t_history\"] = df[\"t_history\"].map(lambda x: x[1:] if len(x) > 1 else x)\n", + "df[\"r_history\"] = df[\"r_history\"].map(lambda x: x[1:] if len(x) > 1 else x)\n", + "df.to_csv('revlog_history.tsv', sep=\"\\t\", index=False)\n", + "print(\"Trainset saved.\")\n", + "\n", + "def cal_retention(group: pd.DataFrame) -> pd.DataFrame:\n", + " group['retention'] = round(group['r'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x]).mean(), 4)\n", + " group['total_cnt'] = group.shape[0]\n", + " return group\n", + "\n", + "df = df.groupby(by=['r_history', 'delta_t']).progress_apply(cal_retention)\n", + "print(\"Retention calculated.\")\n", + "df = df.drop(columns=['id', 'cid', 'usn', 'ivl', 'last_lvl', 'factor', 'time', 'type', 'create_date', 'review_date', 'real_days', 'r', 't_history'])\n", + "df.drop_duplicates(inplace=True)\n", + "df = df[(df['retention'] < 1) & (df['retention'] > 0)]\n", + "\n", + "def cal_stability(group: pd.DataFrame) -> pd.DataFrame:\n", + " if group['i'].values[0] > 1:\n", + " r_ivl_cnt = sum(group['delta_t'] * group['retention'].map(np.log) * pow(group['total_cnt'], 2))\n", + " ivl_ivl_cnt = sum(group['delta_t'].map(lambda x: x ** 2) * pow(group['total_cnt'], 2))\n", + " group['stability'] = round(np.log(0.9) / (r_ivl_cnt / ivl_ivl_cnt), 1)\n", + " else:\n", + " group['stability'] = 0.0\n", + " group['group_cnt'] = sum(group['total_cnt'])\n", + " group['avg_retention'] = round(sum(group['retention'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 3)\n", + " group['avg_interval'] = round(sum(group['delta_t'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 1)\n", + " del group['total_cnt']\n", + " del group['retention']\n", + " del group['delta_t']\n", + " return group\n", + "\n", + "df = df.groupby(by=['r_history']).progress_apply(cal_stability)\n", + "print(\"Stability calculated.\")\n", + "df.reset_index(drop = True, inplace = True)\n", + "df.drop_duplicates(inplace=True)\n", + "df.sort_values(by=['r_history'], inplace=True, ignore_index=True)\n", + "\n", + "if df.shape[0] > 0:\n", + " for idx in tqdm.notebook.tqdm(df.index):\n", + " item = df.loc[idx]\n", + " index = df[(df['i'] == item['i'] + 1) & (df['r_history'].str.startswith(item['r_history']))].index\n", + " df.loc[index, 'last_stability'] = item['stability']\n", + " df['factor'] = round(df['stability'] / df['last_stability'], 2)\n", + " df = df[(df['i'] >= 2) & (df['group_cnt'] >= 100)]\n", + " df['last_recall'] = df['r_history'].map(lambda x: x[-1])\n", + " df = df[df.groupby(['i', 'r_history'])['group_cnt'].transform(max) == df['group_cnt']]\n", + " df.to_csv('./stability_for_analysis.tsv', sep='\\t', index=None)\n", + " print(\"1:again, 2:hard, 3:good, 4:easy\\n\")\n", + " print(df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(index=False))\n", + " print(\"Analysis saved!\")" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "1:again, 2:hard, 3:good, 4:easy\n", - "\n", - " r_history avg_interval avg_retention stability factor group_cnt\n", - " 1 1.7 0.765 1.0 inf 7978\n", - " 1,3 3.9 0.876 4.3 4.30 4155\n", - " 1,3,3 8.6 0.883 9.2 2.14 2684\n", - " 1,3,3,3 17.8 0.857 13.8 1.50 1483\n", - " 1,3,3,3,3 37.0 0.812 19.4 1.41 606\n", - "1,3,3,3,3,3 77.1 0.708 23.1 1.19 128\n", - " 2 1.0 0.901 1.1 inf 234\n", - " 2,3 3.2 0.943 6.3 5.73 154\n", - " 3 1.5 0.962 5.4 inf 9070\n", - " 3,3 3.9 0.966 15.2 2.81 6527\n", - " 3,3,3 9.0 0.960 23.5 1.55 5036\n", - " 3,3,3,3 18.6 0.941 35.2 1.50 3052\n", - " 3,3,3,3,3 39.5 0.914 46.9 1.33 1423\n", - "3,3,3,3,3,3 74.3 0.853 55.6 1.19 411\n", - " 4 3.8 0.966 12.1 inf 11436\n", - " 4,3 8.1 0.975 38.9 3.21 7367\n", - " 4,3,3 18.0 0.963 57.7 1.48 5147\n", - " 4,3,3,3 34.0 0.947 77.2 1.34 2525\n", - " 4,3,3,3,3 46.3 0.906 50.1 0.65 452\n", - "Analysis saved!\n" - ] - } - ], - "source": [ - "if os.path.isfile(\"collection.anki21b\"):\n", - " os.remove(\"collection.anki21b\")\n", - " raise Exception(\n", - " \"Please export the file with `support older Anki versions` if you use the latest version of Anki.\")\n", - "elif os.path.isfile(\"collection.anki21\"):\n", - " con = sqlite3.connect(\"collection.anki21\")\n", - "elif os.path.isfile(\"collection.anki2\"):\n", - " con = sqlite3.connect(\"collection.anki2\")\n", - "else:\n", - " raise Exception(\"Collection not exist!\")\n", - "cur = con.cursor()\n", - "res = cur.execute(\"SELECT * FROM revlog\")\n", - "revlog = res.fetchall()\n", - "\n", - "df = pd.DataFrame(revlog)\n", - "df.columns = ['id', 'cid', 'usn', 'r', 'ivl',\n", - " 'last_lvl', 'factor', 'time', 'type']\n", - "df = df[(df['cid'] <= time.time() * 1000) &\n", - " (df['id'] <= time.time() * 1000) &\n", - " (df['r'] > 0) &\n", - " (df['id'] >= time.mktime(datetime.strptime(revlog_start_date, \"%Y-%m-%d\").timetuple()) * 1000)].copy()\n", - "df['create_date'] = pd.to_datetime(df['cid'] // 1000, unit='s')\n", - "df['create_date'] = df['create_date'].dt.tz_localize(\n", - " 'UTC').dt.tz_convert(timezone)\n", - "df['review_date'] = pd.to_datetime(df['id'] // 1000, unit='s')\n", - "df['review_date'] = df['review_date'].dt.tz_localize(\n", - " 'UTC').dt.tz_convert(timezone)\n", - "df.drop(df[df['review_date'].dt.year < 2006].index, inplace=True)\n", - "df.sort_values(by=['cid', 'id'], inplace=True, ignore_index=True)\n", - "type_sequence = np.array(df['type'])\n", - "df.to_csv(\"revlog.csv\", index=False)\n", - "print(\"revlog.csv saved.\")\n", - "df = df[(df['type'] == 0) | (df['type'] == 1)].copy()\n", - "df['real_days'] = df['review_date'] - timedelta(hours=next_day_starts_at)\n", - "df['real_days'] = pd.DatetimeIndex(df['real_days'].dt.floor('D')).to_julian_date()\n", - "df.drop_duplicates(['cid', 'real_days'], keep='first', inplace=True)\n", - "df['delta_t'] = df.real_days.diff()\n", - "df.dropna(inplace=True)\n", - "df['delta_t'] = df['delta_t'].astype(dtype=int)\n", - "df['i'] = 1\n", - "df['r_history'] = \"\"\n", - "df['t_history'] = \"\"\n", - "col_idx = {key: i for i, key in enumerate(df.columns)}\n", - "\n", - "\n", - "# code from https://github.com/L-M-Sherlock/anki_revlog_analysis/blob/main/revlog_analysis.py\n", - "def get_feature(x):\n", - " for idx, log in enumerate(x.itertuples()):\n", - " if idx == 0:\n", - " x.iloc[idx, col_idx['delta_t']] = 0\n", - " if idx == x.shape[0] - 1:\n", - " break\n", - " x.iloc[idx + 1, col_idx['i']] = x.iloc[idx, col_idx['i']] + 1\n", - " x.iloc[idx + 1, col_idx['t_history']] = f\"{x.iloc[idx, col_idx['t_history']]},{x.iloc[idx, col_idx['delta_t']]}\"\n", - " x.iloc[idx + 1, col_idx['r_history']] = f\"{x.iloc[idx, col_idx['r_history']]},{x.iloc[idx, col_idx['r']]}\"\n", - " return x\n", - "\n", - "tqdm.notebook.tqdm.pandas()\n", - "df = df.groupby('cid', as_index=False).progress_apply(get_feature)\n", - "df[\"t_history\"] = df[\"t_history\"].map(lambda x: x[1:] if len(x) > 1 else x)\n", - "df[\"r_history\"] = df[\"r_history\"].map(lambda x: x[1:] if len(x) > 1 else x)\n", - "df.to_csv('revlog_history.tsv', sep=\"\\t\", index=False)\n", - "print(\"Trainset saved.\")\n", - "\n", - "def cal_retention(group: pd.DataFrame) -> pd.DataFrame:\n", - " group['retention'] = round(group['r'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x]).mean(), 4)\n", - " group['total_cnt'] = group.shape[0]\n", - " return group\n", - "\n", - "df = df.groupby(by=['r_history', 'delta_t']).progress_apply(cal_retention)\n", - "print(\"Retention calculated.\")\n", - "df = df.drop(columns=['id', 'cid', 'usn', 'ivl', 'last_lvl', 'factor', 'time', 'type', 'create_date', 'review_date', 'real_days', 'r', 't_history'])\n", - "df.drop_duplicates(inplace=True)\n", - "df = df[(df['retention'] < 1) & (df['retention'] > 0)]\n", - "\n", - "def cal_stability(group: pd.DataFrame) -> pd.DataFrame:\n", - " if group['i'].values[0] > 1:\n", - " r_ivl_cnt = sum(group['delta_t'] * group['retention'].map(np.log) * pow(group['total_cnt'], 2))\n", - " ivl_ivl_cnt = sum(group['delta_t'].map(lambda x: x ** 2) * pow(group['total_cnt'], 2))\n", - " group['stability'] = round(np.log(0.9) / (r_ivl_cnt / ivl_ivl_cnt), 1)\n", - " else:\n", - " group['stability'] = 0.0\n", - " group['group_cnt'] = sum(group['total_cnt'])\n", - " group['avg_retention'] = round(sum(group['retention'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 3)\n", - " group['avg_interval'] = round(sum(group['delta_t'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 1)\n", - " del group['total_cnt']\n", - " del group['retention']\n", - " del group['delta_t']\n", - " return group\n", - "\n", - "df = df.groupby(by=['r_history']).progress_apply(cal_stability)\n", - "print(\"Stability calculated.\")\n", - "df.reset_index(drop = True, inplace = True)\n", - "df.drop_duplicates(inplace=True)\n", - "df.sort_values(by=['r_history'], inplace=True, ignore_index=True)\n", - "\n", - "if df.shape[0] > 0:\n", - " for idx in tqdm.notebook.tqdm(df.index):\n", - " item = df.loc[idx]\n", - " index = df[(df['i'] == item['i'] + 1) & (df['r_history'].str.startswith(item['r_history']))].index\n", - " df.loc[index, 'last_stability'] = item['stability']\n", - " df['factor'] = round(df['stability'] / df['last_stability'], 2)\n", - " df = df[(df['i'] >= 2) & (df['group_cnt'] >= 100)]\n", - " df['last_recall'] = df['r_history'].map(lambda x: x[-1])\n", - " df = df[df.groupby(['i', 'r_history'])['group_cnt'].transform(max) == df['group_cnt']]\n", - " df.to_csv('./stability_for_analysis.tsv', sep='\\t', index=None)\n", - " print(\"1:again, 2:hard, 3:good, 4:easy\\n\")\n", - " print(df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(index=False))\n", - " print(\"Analysis saved!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "k_SgzC-auWmu" - }, - "source": [ - "## 2 Optimize parameter" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WrfBJjqCHEwJ" - }, - "source": [ - "### 2.1 Define the model\n", - "\n", - "FSRS is a time-series model for predicting memory states." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "tdYp3GMLhTYm" - }, - "outputs": [], - "source": [ - "init_w = [1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.02, 0.8, 2, -0.2, 0.5, 1]\n", - "\n", - "\n", - "class FSRS(nn.Module):\n", - " def __init__(self, w):\n", - " super(FSRS, self).__init__()\n", - " self.w = nn.Parameter(torch.FloatTensor(w))\n", - " self.zero = torch.FloatTensor([0.0])\n", - "\n", - " def forward(self, x, s, d):\n", - " '''\n", - " :param x: [review interval, review response]\n", - " :param s: stability\n", - " :param d: difficulty\n", - " :return:\n", - " '''\n", - " if torch.equal(s, self.zero):\n", - " # first learn, init memory states\n", - " new_s = self.w[0] + self.w[1] * (x[1] - 1)\n", - " new_d = self.w[2] + self.w[3] * (x[1] - 3)\n", - " new_d = new_d.clamp(1, 10)\n", - " else:\n", - " r = torch.exp(np.log(0.9) * x[0] / s)\n", - " new_d = d + self.w[4] * (x[1] - 3)\n", - " new_d = self.mean_reversion(self.w[2], new_d)\n", - " new_d = new_d.clamp(1, 10)\n", - " # recall\n", - " if x[1] > 1:\n", - " new_s = s * (1 + torch.exp(self.w[6]) *\n", - " (11 - new_d) *\n", - " torch.pow(s, self.w[7]) *\n", - " (torch.exp((1 - r) * self.w[8]) - 1))\n", - " # forget\n", - " else:\n", - " new_s = self.w[9] * torch.pow(new_d, self.w[10]) * torch.pow(\n", - " s, self.w[11]) * torch.exp((1 - r) * self.w[12])\n", - " return new_s, new_d\n", - "\n", - " def loss(self, s, t, r):\n", - " return - (r * np.log(0.9) * t / s + (1 - r) * torch.log(1 - torch.exp(np.log(0.9) * t / s)))\n", - "\n", - " def mean_reversion(self, init, current):\n", - " return self.w[5] * init + (1-self.w[5]) * current\n", - "\n", - "\n", - "class WeightClipper(object):\n", - " def __init__(self, frequency=1):\n", - " self.frequency = frequency\n", - "\n", - " def __call__(self, module):\n", - " if hasattr(module, 'w'):\n", - " w = module.w.data\n", - " w[0] = w[0].clamp(0.1, 10) # initStability\n", - " w[1] = w[1].clamp(0.1, 5) # initStabilityRatingFactor\n", - " w[2] = w[2].clamp(1, 10) # initDifficulty\n", - " w[3] = w[3].clamp(-5, -0.1) # initDifficultyRatingFactor\n", - " w[4] = w[4].clamp(-5, -0.1) # updateDifficultyRatingFactor\n", - " w[5] = w[5].clamp(0, 0.5) # difficultyMeanReversionFactor\n", - " w[6] = w[6].clamp(0, 2) # recallFactor\n", - " w[7] = w[7].clamp(-0.2, -0.01) # recallStabilityDecay\n", - " w[8] = w[8].clamp(0.01, 1.5) # recallRetrievabilityFactor\n", - " w[9] = w[9].clamp(0.5, 5) # forgetFactor\n", - " w[10] = w[10].clamp(-2, -0.01) # forgetDifficultyDecay\n", - " w[11] = w[11].clamp(0.01, 0.9) # forgetStabilityDecay\n", - " w[12] = w[12].clamp(0.01, 2) # forgetRetrievabilityFactor\n", - " module.w.data = w\n", - "\n", - "def lineToTensor(line):\n", - " ivl = line[0].split(',')\n", - " response = line[1].split(',')\n", - " tensor = torch.zeros(len(response), 2)\n", - " for li, response in enumerate(response):\n", - " tensor[li][0] = int(ivl[li])\n", - " tensor[li][1] = int(response)\n", - " return tensor\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8E1dYfgQLZAC" - }, - "source": [ - "### 2.2 Train the model\n", - "\n", - "The `revlog_history.tsv` generated before will be used for training the FSRS model." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "markdown", + "metadata": { + "id": "k_SgzC-auWmu" + }, + "source": [ + "## 2 Optimize parameter" + ] }, - "id": "Jht0gneShowU", - "outputId": "aaa72b79-b454-483b-d746-df1a353b2c8f" - }, - "outputs": [ { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9fcd540afa254d6082b0f51400d633a4", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "WrfBJjqCHEwJ" }, - "text/plain": [ - " 0%| | 0/225934 [00:00 1:\n", + " new_s = s * (1 + torch.exp(self.w[6]) *\n", + " (11 - new_d) *\n", + " torch.pow(s, self.w[7]) *\n", + " (torch.exp((1 - r) * self.w[8]) - 1))\n", + " # forget\n", + " else:\n", + " new_s = self.w[9] * torch.pow(new_d, self.w[10]) * torch.pow(\n", + " s, self.w[11]) * torch.exp((1 - r) * self.w[12])\n", + " return new_s, new_d\n", + "\n", + " def loss(self, s, t, r):\n", + " return - (r * np.log(0.9) * t / s + (1 - r) * torch.log(1 - torch.exp(np.log(0.9) * t / s)))\n", + "\n", + " def mean_reversion(self, init, current):\n", + " return self.w[5] * init + (1-self.w[5]) * current\n", + "\n", + "\n", + "class WeightClipper(object):\n", + " def __init__(self, frequency=1):\n", + " self.frequency = frequency\n", + "\n", + " def __call__(self, module):\n", + " if hasattr(module, 'w'):\n", + " w = module.w.data\n", + " w[0] = w[0].clamp(0.1, 10)\n", + " w[1] = w[1].clamp(0.1, 5)\n", + " w[2] = w[2].clamp(1, 10)\n", + " w[3] = w[3].clamp(-5, -0.1)\n", + " w[4] = w[4].clamp(-5, -0.1)\n", + " w[5] = w[5].clamp(0, 0.5)\n", + " w[6] = w[6].clamp(0, 2)\n", + " w[7] = w[7].clamp(-0.2, -0.01)\n", + " w[8] = w[8].clamp(0.01, 1.5)\n", + " w[9] = w[9].clamp(0.5, 5)\n", + " w[10] = w[10].clamp(-2, -0.01)\n", + " w[11] = w[11].clamp(0.01, 0.9)\n", + " w[12] = w[12].clamp(0.01, 2)\n", + " module.w.data = w\n", + "\n", + "def lineToTensor(line):\n", + " ivl = line[0].split(',')\n", + " response = line[1].split(',')\n", + " tensor = torch.zeros(len(response), 2)\n", + " for li, response in enumerate(response):\n", + " tensor[li][0] = int(ivl[li])\n", + " tensor[li][1] = int(response)\n", + " return tensor\n" + ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8539e0a6cc3f4418859417f762a4b2cd", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "8E1dYfgQLZAC" }, - "text/plain": [ - "pre-train: 0%| | 0/28972 [00:00 1) & (dataset['delta_t'] > 0) & (dataset['t_history'].str.count(',0') == 0)]\n", + "dataset['tensor'] = dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]), axis=1)\n", + "print(\"Tensorized!\")\n", + "\n", + "pre_train_set = dataset[dataset['i'] == 2]\n", + "# pretrain\n", + "epoch_len = len(pre_train_set)\n", + "n_epoch = 1\n", + "pbar = tqdm.notebook.tqdm(desc=\"pre-train\", colour=\"red\", total=epoch_len*n_epoch)\n", + "\n", + "for k in range(n_epoch):\n", + " for i, (_, row) in enumerate(shuffle(pre_train_set, random_state=2022 + k).iterrows()):\n", + " model.train()\n", + " optimizer.zero_grad()\n", + " output_t = [(model.zero, model.zero)]\n", + " for input_t in row['tensor']:\n", + " output_t.append(model(input_t, *output_t[-1]))\n", + " loss = model.loss(output_t[-1][0], row['delta_t'],\n", + " {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])\n", + " if np.isnan(loss.data.item()):\n", + " # Exception Case\n", + " print(row, output_t)\n", + " raise Exception('error case')\n", + " loss.backward()\n", + " optimizer.step()\n", + " model.apply(clipper)\n", + " pbar.update()\n", + "pbar.close()\n", + "for name, param in model.named_parameters():\n", + " print(f\"{name}: {list(map(lambda x: round(float(x), 4),param))}\")\n", + "\n", + "train_set = dataset[dataset['i'] > 2]\n", + "epoch_len = len(train_set)\n", + "n_epoch = 1\n", + "print_len = max(epoch_len*n_epoch // 10, 1)\n", + "pbar = tqdm.notebook.tqdm(desc=\"train\", colour=\"red\", total=epoch_len*n_epoch)\n", + "\n", + "for k in range(n_epoch):\n", + " for i, (_, row) in enumerate(shuffle(train_set, random_state=2022 + k).iterrows()):\n", + " model.train()\n", + " optimizer.zero_grad()\n", + " output_t = [(model.zero, model.zero)]\n", + " for input_t in row['tensor']:\n", + " output_t.append(model(input_t, *output_t[-1]))\n", + " loss = model.loss(output_t[-1][0], row['delta_t'],\n", + " {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])\n", + " if np.isnan(loss.data.item()):\n", + " # Exception Case\n", + " print(row, output_t)\n", + " raise Exception('error case')\n", + " loss.backward()\n", + " for param in model.parameters():\n", + " param.grad[:2] = torch.zeros(2)\n", + " optimizer.step()\n", + " model.apply(clipper)\n", + " pbar.update()\n", + "\n", + " if (k * epoch_len + i) % print_len == 0:\n", + " print(f\"iteration: {k * epoch_len + i + 1}\")\n", + " for name, param in model.named_parameters():\n", + " print(f\"{name}: {list(map(lambda x: round(float(x), 4),param))}\")\n", + "pbar.close()\n", + "\n", + "w = list(map(lambda x: round(float(x), 4), dict(model.named_parameters())['w'].data))\n", + "\n", + "print(\"\\nTraining finished!\")\n" + ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cc99188751a544b4a3481473dda3b179", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "BZ4S2l7BWfzr" }, - "text/plain": [ - "train: 0%| | 0/196962 [00:00 1) & (dataset['delta_t'] > 0) & (dataset['t_history'].str.count(',0') == 0)]\n", - "dataset['tensor'] = dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]), axis=1)\n", - "print(\"Tensorized!\")\n", - "\n", - "pre_train_set = dataset[dataset['i'] == 2]\n", - "# pretrain\n", - "epoch_len = len(pre_train_set)\n", - "n_epoch = 1\n", - "pbar = tqdm.notebook.tqdm(desc=\"pre-train\", colour=\"red\", total=epoch_len*n_epoch)\n", - "\n", - "for k in range(n_epoch):\n", - " for i, (_, row) in enumerate(shuffle(pre_train_set, random_state=2022 + k).iterrows()):\n", - " model.train()\n", - " optimizer.zero_grad()\n", - " output_t = [(model.zero, model.zero)]\n", - " for input_t in row['tensor']:\n", - " output_t.append(model(input_t, *output_t[-1]))\n", - " loss = model.loss(output_t[-1][0], row['delta_t'],\n", - " {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])\n", - " if np.isnan(loss.data.item()):\n", - " # Exception Case\n", - " print(row, output_t)\n", - " raise Exception('error case')\n", - " loss.backward()\n", - " optimizer.step()\n", - " model.apply(clipper)\n", - " pbar.update()\n", - "pbar.close()\n", - "for name, param in model.named_parameters():\n", - " print(f\"{name}: {list(map(lambda x: round(float(x), 4),param))}\")\n", - "\n", - "train_set = dataset[dataset['i'] > 2]\n", - "epoch_len = len(train_set)\n", - "n_epoch = 1\n", - "print_len = max(epoch_len*n_epoch // 10, 1)\n", - "pbar = tqdm.notebook.tqdm(desc=\"train\", colour=\"red\", total=epoch_len*n_epoch)\n", - "\n", - "for k in range(n_epoch):\n", - " for i, (_, row) in enumerate(shuffle(train_set, random_state=2022 + k).iterrows()):\n", - " model.train()\n", - " optimizer.zero_grad()\n", - " output_t = [(model.zero, model.zero)]\n", - " for input_t in row['tensor']:\n", - " output_t.append(model(input_t, *output_t[-1]))\n", - " loss = model.loss(output_t[-1][0], row['delta_t'],\n", - " {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])\n", - " if np.isnan(loss.data.item()):\n", - " # Exception Case\n", - " print(row, output_t)\n", - " raise Exception('error case')\n", - " loss.backward()\n", - " for param in model.parameters():\n", - " param.grad[:2] = torch.zeros(2)\n", - " optimizer.step()\n", - " model.apply(clipper)\n", - " pbar.update()\n", - "\n", - " if (k * epoch_len + i) % print_len == 0:\n", - " print(f\"iteration: {k * epoch_len + i + 1}\")\n", - " for name, param in model.named_parameters():\n", - " print(f\"{name}: {list(map(lambda x: round(float(x), 4),param))}\")\n", - "pbar.close()\n", - "\n", - "w = list(map(lambda x: round(float(x), 4), dict(model.named_parameters())['w'].data))\n", - "\n", - "print(\"\\nTraining finished!\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BZ4S2l7BWfzr" - }, - "source": [ - "## 3 Result\n", - "\n", - "Copy the optimal parameters for FSRS for you in the output of next code cell after running." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NTnPSDA2QpUu", + "outputId": "49f487b9-69a7-4e96-b35a-7e027f478fbd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "var w = [1.014, 2.2933, 4.9444, -1.1646, -0.9942, 0.0227, 1.3911, -0.0498, 0.7376, 1.7016, -0.4742, 0.602, 0.9946];\n" + ] + } + ], + "source": [ + "print(f\"var w = {w};\")" + ] }, - "id": "NTnPSDA2QpUu", - "outputId": "49f487b9-69a7-4e96-b35a-7e027f478fbd" - }, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "var w = [1.014, 2.2933, 4.9588, -1.1608, -0.9954, 0.0234, 1.3923, -0.0484, 0.7363, 1.6937, -0.4708, 0.6032, 0.9762];\n" - ] - } - ], - "source": [ - "print(f\"var w = {w};\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I_zsoDyTaTrT" - }, - "source": [ - "You can see the memory states and intervals generated by FSRS as if you press the good in each review at the due date scheduled by FSRS." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4 Preview" + ] }, - "id": "iws4rtP1WKBT", - "outputId": "890d0287-1a17-4c59-fbbf-ee54d79cd383" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1:again, 2:hard, 3:good, 4:easy\n", - "\n", - "first rating: 1\n", - "rating history: 1,3,3,3,3,3,3,3,3,3,3\n", - "interval history: 0,1,2,4,9,19,39,79,159,317,624\n", - "difficulty history: 0,7.3,7.2,7.2,7.1,7.1,7.0,7.0,6.9,6.9,6.8\n", - "\n", - "first rating: 2\n", - "rating history: 2,3,3,3,3,3,3,3,3,3,3\n", - "interval history: 0,3,8,19,44,100,223,489,1052,2226,4631\n", - "difficulty history: 0,6.1,6.1,6.1,6.0,6.0,6.0,6.0,5.9,5.9,5.9\n", - "\n", - "first rating: 3\n", - "rating history: 3,3,3,3,3,3,3,3,3,3,3\n", - "interval history: 0,6,16,42,107,265,641,1512,3483,7842,17280\n", - "difficulty history: 0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0\n", - "\n", - "first rating: 4\n", - "rating history: 4,3,3,3,3,3,3,3,3,3,3\n", - "interval history: 0,8,24,69,192,517,1348,3409,8376,20022,46625\n", - "difficulty history: 0,3.8,3.8,3.9,3.9,3.9,3.9,4.0,4.0,4.0,4.0\n", - "\n" - ] - } - ], - "source": [ - "requestRetention = 0.9 # recommended setting: 0.8 ~ 0.9\n", - "\n", - "\n", - "class Collection:\n", - " def __init__(self, w):\n", - " self.model = FSRS(w)\n", - "\n", - " def states(self, t_history, r_history):\n", - " with torch.no_grad():\n", - " line_tensor = lineToTensor(list(zip([t_history], [r_history]))[0])\n", - " output_t = [(self.model.zero, self.model.zero)]\n", - " for input_t in line_tensor:\n", - " output_t.append(self.model(input_t, *output_t[-1]))\n", - " return output_t[-1]\n", - "\n", - "\n", - "my_collection = Collection(w)\n", - "print(\"1:again, 2:hard, 3:good, 4:easy\\n\")\n", - "for first_rating in (1,2,3,4):\n", - " print(f'first rating: {first_rating}')\n", - " t_history = \"0\"\n", - " d_history = \"0\"\n", - " r_history = f\"{first_rating}\" # the first rating of the new card\n", - " # print(\"stability, difficulty, lapses\")\n", - " for i in range(10):\n", - " states = my_collection.states(t_history, r_history)\n", - " # print('{0:9.2f} {1:11.2f} {2:7.0f}'.format(\n", - " # *list(map(lambda x: round(float(x), 4), states))))\n", - " next_t = max(round(float(np.log(requestRetention)/np.log(0.9) * states[0])), 1)\n", - " difficulty = round(float(states[1]), 1)\n", - " t_history += f',{int(next_t)}'\n", - " d_history += f',{difficulty}'\n", - " r_history += f\",3\"\n", - " print(f\"rating history: {r_history}\")\n", - " print(f\"interval history: {t_history}\")\n", - " print(f\"difficulty history: {d_history}\")\n", - " print('')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can change the `test_rating_sequence` to see the scheduling intervals in different ratings." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "(tensor(5.6006), tensor(4.9588))\n", - "(tensor(15.8420), tensor(4.9588))\n", - "(tensor(41.8382), tensor(4.9588))\n", - "(tensor(106.9516), tensor(4.9588))\n", - "(tensor(265.4778), tensor(4.9588))\n", - "(tensor(21.7929), tensor(6.9030))\n", - "(tensor(4.3071), tensor(8.8017))\n", - "(tensor(6.9323), tensor(8.7118))\n", - "(tensor(11.5881), tensor(8.6240))\n", - "(tensor(19.6509), tensor(8.5382))\n", - "(tensor(33.1994), tensor(8.4545))\n", - "(tensor(55.7034), tensor(8.3727))\n", - "rating history: 3,3,3,3,3,1,1,3,3,3,3,3\n", - "interval history: 0,6,16,42,107,265,22,4,7,12,20,33,56\n", - "difficulty history: 0,5.0,5.0,5.0,5.0,5.0,6.9,8.8,8.7,8.6,8.5,8.5,8.4\n" - ] - } - ], - "source": [ - "test_rating_sequence = \"3,3,3,3,3,1,1,3,3,3,3,3\"\n", - "requestRetention = 0.9 # recommended setting: 0.8 ~ 0.9\n", - "easyBonus = 1.3\n", - "hardInterval = 1.2\n", - "\n", - "t_history = \"0\"\n", - "d_history = \"0\"\n", - "for i in range(len(test_rating_sequence.split(','))):\n", - " rating = test_rating_sequence[2*i]\n", - " last_t = int(t_history.split(',')[-1])\n", - " r_history = test_rating_sequence[:2*i+1]\n", - " states = my_collection.states(t_history, r_history)\n", - " print(states)\n", - " next_t = max(1,round(float(np.log(requestRetention)/np.log(0.9) * states[0])))\n", - " if rating == '4':\n", - " next_t = round(next_t * easyBonus)\n", - " elif rating == '2':\n", - " next_t = round(last_t * hardInterval)\n", - " t_history += f',{int(next_t)}'\n", - " difficulty = round(float(states[1]), 1)\n", - " d_history += f',{difficulty}'\n", - "print(f\"rating history: {test_rating_sequence}\")\n", - "print(f\"interval history: {t_history}\")\n", - "print(f\"difficulty history: {d_history}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Predict memory states for each review and save them in `prediction.tsv`.\n", - "\n", - "Meanwhile, it will count the distribution of difficulty." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0bc168078d174ab88d48a1a24056bd7f", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "I_zsoDyTaTrT" }, - "text/plain": [ - " 0%| | 0/119670 [00:00" + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f99be96c169b45f8a03d1355e71b679a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/119670 [00:00 0.1:\n", - " s0_repetitions = repetitions_list[d - 1][s0_index]\n", - " for s_index in range(index_len - 2, -1, -1):\n", - " stability = stability_list[s_index];\n", - " interval = max(1, round(stability * np.log(recall) / np.log(0.9)))\n", - " p_recall = np.power(0.9, interval / stability)\n", - " recall_s = cal_next_recall_stability(stability, p_recall, d, 1)\n", - " forget_d = min(d + d_offset, 10)\n", - " forget_s = cal_next_recall_stability(stability, p_recall, forget_d, 0)\n", - " recall_s_index = min(stability2index(recall_s), index_len - 1)\n", - " forget_s_index = min(max(stability2index(forget_s), 0), index_len - 1)\n", - " recall_repetitions = repetitions_list[d - 1][recall_s_index] + r_repetitions\n", - " forget_repetitions = repetitions_list[forget_d - 1][forget_s_index] + f_repetitions\n", - " exp_repetitions = p_recall * recall_repetitions + (1.0 - p_recall) * forget_repetitions\n", - " if exp_repetitions < repetitions_list[d - 1][s_index]:\n", - " repetitions_list[d - 1][s_index] = exp_repetitions\n", - " diff = s0_repetitions - repetitions_list[d - 1][s0_index]\n", - " df.loc[0 if pd.isnull(df.index.max()) else df.index.max() + 1] = [recall, d, s0_repetitions]\n", - "\n", - "df.sort_values(by=[\"difficulty\", \"retention\"], inplace=True)\n", - "df.to_csv(\"./expected_repetitions.csv\", index=False)\n", - "print(\"expected_repetitions.csv saved.\")\n", - "\n", - "optimal_retention_list = np.zeros(10)\n", - "for d in range(1, d_range+1):\n", - " retention = df[df[\"difficulty\"] == d][\"retention\"]\n", - " repetitions = df[df[\"difficulty\"] == d][\"repetitions\"]\n", - " optimal_retention = retention.iat[repetitions.argmin()]\n", - " optimal_retention_list[d-1] = optimal_retention\n", - " plt.plot(retention, repetitions, label=f\"d={d}, r={optimal_retention}\")\n", - "print(f\"\\n-----suggested retention: {np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}-----\")\n", - "plt.ylabel(\"expected repetitions\")\n", - "plt.xlabel(\"retention\")\n", - "plt.legend()\n", - "plt.grid()\n", - "plt.semilogy()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Evaluate the model with the log loss. It will compare the log loss between initial model and trained model." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ + }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6dd615097589494b9829477848785fcb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/225934 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "base = 1.01\n", + "index_len = 793\n", + "index_offset = 200\n", + "d_range = 10\n", + "d_offset = 1\n", + "r_time = 8\n", + "f_time = 25\n", + "max_time = 200000\n", + "\n", + "type_block = dict()\n", + "type_count = dict()\n", + "type_time = dict()\n", + "last_t = type_sequence[0]\n", + "type_block[last_t] = 1\n", + "type_count[last_t] = 1\n", + "type_time[last_t] = time_sequence[0]\n", + "for i,t in enumerate(type_sequence[1:]):\n", + " type_count[t] = type_count.setdefault(t, 0) + 1\n", + " type_time[t] = type_time.setdefault(t, 0) + time_sequence[i]\n", + " if t != last_t:\n", + " type_block[t] = type_block.setdefault(t, 0) + 1\n", + " last_t = t\n", + "\n", + "r_time = round(type_time[1]/type_count[1]/1000, 1)\n", + "\n", + "if 2 in type_count and 2 in type_block:\n", + " f_time = round(type_time[2]/type_block[2]/1000 + r_time, 1)\n", + "\n", + "print(f\"average time for failed cards: {f_time}s\")\n", + "print(f\"average time for recalled cards: {r_time}s\")\n", + "\n", + "def stability2index(stability):\n", + " return int(round(np.log(stability) / np.log(base)) + index_offset)\n", + "\n", + "def init_stability(d):\n", + " return max(((d - w[2]) / w[3] + 2) * w[1] + w[0], np.power(base, -index_offset))\n", + "\n", + "def cal_next_recall_stability(s, r, d, response):\n", + " if response == 1:\n", + " return s * (1 + np.exp(w[6]) * (11 - d) * np.power(s, w[7]) * (np.exp((1 - r) * w[8]) - 1))\n", + " else:\n", + " return w[9] * np.power(d, w[10]) * np.power(s, w[11]) * np.exp((1 - r) * w[12])\n", + "\n", + "\n", + "stability_list = np.array([np.power(base, i - index_offset) for i in range(index_len)])\n", + "print(f\"terminal stability: {stability_list.max(): .2f}\")\n", + "df = pd.DataFrame(columns=[\"retention\", \"difficulty\", \"time\"])\n", + "\n", + "for percentage in tqdm.notebook.tqdm(range(96, 66, -2)):\n", + " recall = percentage / 100\n", + " time_list = np.zeros((d_range, index_len))\n", + " time_list[:,:-1] = max_time\n", + " for d in range(d_range, 0, -1):\n", + " s0 = init_stability(d)\n", + " s0_index = stability2index(s0)\n", + " diff = max_time\n", + " while diff > 0.1:\n", + " s0_time = time_list[d - 1][s0_index]\n", + " for s_index in range(index_len - 2, -1, -1):\n", + " stability = stability_list[s_index];\n", + " interval = max(1, round(stability * np.log(recall) / np.log(0.9)))\n", + " p_recall = np.power(0.9, interval / stability)\n", + " recall_s = cal_next_recall_stability(stability, p_recall, d, 1)\n", + " forget_d = min(d + d_offset, 10)\n", + " forget_s = cal_next_recall_stability(stability, p_recall, forget_d, 0)\n", + " recall_s_index = min(stability2index(recall_s), index_len - 1)\n", + " forget_s_index = min(max(stability2index(forget_s), 0), index_len - 1)\n", + " recall_time = time_list[d - 1][recall_s_index] + r_time\n", + " forget_time = time_list[forget_d - 1][forget_s_index] + f_time\n", + " exp_time = p_recall * recall_time + (1.0 - p_recall) * forget_time\n", + " if exp_time < time_list[d - 1][s_index]:\n", + " time_list[d - 1][s_index] = exp_time\n", + " diff = s0_time - time_list[d - 1][s0_index]\n", + " df.loc[0 if pd.isnull(df.index.max()) else df.index.max() + 1] = [recall, d, s0_time]\n", + "\n", + "df.sort_values(by=[\"difficulty\", \"retention\"], inplace=True)\n", + "df.to_csv(\"./expected_time.csv\", index=False)\n", + "print(\"expected_time.csv saved.\")\n", + "\n", + "optimal_retention_list = np.zeros(10)\n", + "for d in range(1, d_range+1):\n", + " retention = df[df[\"difficulty\"] == d][\"retention\"]\n", + " time = df[df[\"difficulty\"] == d][\"time\"]\n", + " optimal_retention = retention.iat[time.argmin()]\n", + " optimal_retention_list[d-1] = optimal_retention\n", + " plt.plot(retention, time, label=f\"d={d}, r={optimal_retention}\")\n", + "print(f\"\\n-----suggested retention (experimental): {np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}-----\")\n", + "plt.ylabel(\"expected time (second)\")\n", + "plt.xlabel(\"retention\")\n", + "plt.legend()\n", + "plt.grid()\n", + "plt.semilogy()\n", + "plt.show()" + ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "47fa4379ac954b9b810af3afdab10cea", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/225934 [00:00