derek-thomas (HF staff) committed on
Commit 651b002
Parent: 392c159

Init commit

Files changed (9)
  1. .gitignore +1 -0
  2. app.py +112 -0
  3. fsrs4anki_optimizer.ipynb +0 -0
  4. memory_states.py +35 -0
  5. model.py +93 -0
  6. plot.py +92 -0
  7. projects/.gitkeep +0 -0
  8. requirements.txt +7 -0
  9. utilities.py +296 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea/
app.py ADDED
@@ -0,0 +1,112 @@
+ import gradio as gr
+ import pytz
+
+ from datetime import datetime
+
+ from utilities import extract, create_time_series_features, train_model, process_personalized_collection, my_loss, \
+     cleanup
+ from memory_states import get_my_memory_states
+ from plot import make_plot
+
+
+ def anki_optimizer(file, timezone, next_day_starts_at, revlog_start_date, requestRetention,
+                    progress=gr.Progress(track_tqdm=True)):
+     now = datetime.now()
+     prefix = now.strftime(f'%Y_%m_%d_%H_%M_%S')
+     proj_dir = extract(file, prefix)
+     type_sequence, df_out = create_time_series_features(revlog_start_date, timezone, next_day_starts_at, proj_dir)
+     w, dataset = train_model(proj_dir)
+     my_collection, rating_markdown = process_personalized_collection(requestRetention, w)
+     difficulty_distribution_padding, difficulty_distribution = get_my_memory_states(proj_dir, dataset, my_collection)
+     fig, suggested_retention_markdown = make_plot(proj_dir, type_sequence, w, difficulty_distribution_padding)
+     loss_markdown = my_loss(dataset, w)
+     difficulty_distribution = difficulty_distribution.to_string().replace("\n", "\n\n")
+     markdown_out = f"""
+ {suggested_retention_markdown}
+
+ # Loss Information
+ {loss_markdown}
+
+ # Difficulty Distribution
+ {difficulty_distribution}
+
+ # Ratings
+ {rating_markdown}
+ """
+
+     w_markdown = f"""
+ # These are the weights for step 5
+ `var w = {w};`"""
+     files = ['prediction.tsv', 'revlog.csv', 'revlog_history.tsv', 'stability_for_analysis.tsv',
+              'expected_repetitions.csv']
+     files_out = [proj_dir / file for file in files]
+     cleanup(proj_dir, files)
+     return w_markdown, df_out, fig, markdown_out, files_out
+
+
+ with gr.Blocks() as demo:
+     with gr.Tab("FSRS4Anki Optimizer"):
+         with gr.Box():
+             gr.Markdown("""
+ Based on the [tutorial](https://medium.com/@JarrettYe/how-to-use-the-next-generation-spaced-repetition-algorithm-fsrs-on-anki-5a591ca562e2) of [Jarrett Ye](https://github.com/L-M-Sherlock).
+ Check out the instructions on the next tab.
+ """)
+         with gr.Box():
+             with gr.Row():
+                 file = gr.File(label='Review Logs')
+                 timezone = gr.Dropdown(label="Choose your timezone", choices=pytz.all_timezones)
+             with gr.Row():
+                 next_day_starts_at = gr.Number(value=4,
+                                                label="Next day starts at: replace it with your Anki setting in Preferences -> Scheduling.",
+                                                precision=0)
+             with gr.Accordion(label="Advanced Settings", open=False):
+                 requestRetention = gr.Number(value=0.9, label="Requested retention: recommended to set between 0.8 and 0.9")
+                 with gr.Row():
+                     revlog_start_date = gr.Textbox(value="2006-10-05",
+                                                    label="Set this if you don't want the optimizer to use review logs from before a specific date.")
+         with gr.Row():
+             btn_plot = gr.Button('Optimize your Anki!')
+         with gr.Row():
+             w_output = gr.Markdown()
+     with gr.Tab("Instructions"):
+         with gr.Box():
+             gr.Markdown("""
+ # How to get personalized Anki parameters
+ If you have been using Anki for some time and have accumulated a lot of review logs, you can try this FSRS4Anki
+ optimizer app to generate parameters for you.
+
+ This is based on the amazing work of [Jarrett Ye](https://github.com/L-M-Sherlock).
+ # Step 1 - Get the review logs to upload
+ 1. Click the gear icon to the right of a deck’s name
+ 2. Export
+ 3. Check “Include scheduling information” and “Support older Anki versions”
+ ![](https://miro.medium.com/v2/resize:fit:1400/format:webp/1*W3Nnfarki2z7Ukyom4kMuw.png)
+ 4. Export and upload that file to the app
+
+ # Step 2 - Get the `next_day_starts_at` parameter
+ 1. Open Preferences
+ 2. Copy the "next day starts at" value and paste it into the app
+ ![](https://miro.medium.com/v2/resize:fit:1072/format:webp/1*qAUb6ry8UxFeCsjnKLXvsQ.png)
+
+ # Step 3 - Fill in the rest of the settings
+
+ # Step 4 - Click run
+
+ # Step 5 - Replace the default parameters in FSRS4Anki with the optimized parameters
+ ![](https://miro.medium.com/v2/resize:fit:1252/format:webp/1*NM4CR-n7nDk3nQN1Bi30EA.png)
+ """)
+     with gr.Tab("Analysis"):
+         with gr.Row():
+             markdown_output = gr.Markdown()
+             df_output = gr.DataFrame()
+         with gr.Row():
+             plot_output = gr.Plot()
+         with gr.Row():
+             files_output = gr.Files(label="Analysis Files")
+
+     btn_plot.click(anki_optimizer,
+                    inputs=[file, timezone, next_day_starts_at, revlog_start_date, requestRetention],
+                    outputs=[w_output, df_output, plot_output, markdown_output, files_output])
+ demo.queue().launch(debug=True, show_error=True)
+
+ # demo.queue().launch(debug=True)
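
Note (illustrative, not part of the commit): the `w_markdown` returned above renders the trained weights in exactly the form step 5 of the instructions asks you to paste into the FSRS4Anki scheduler config, i.e. a 13-element list matching `init_w` in model.py, e.g. `var w = [1.14, 1.01, 5.44, -1.25, -1.05, 0.08, 1.65, -0.16, 1.22, 2.17, -0.33, 0.72, 1.34];` (numbers here are made up).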
fsrs4anki_optimizer.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
memory_states.py ADDED
@@ -0,0 +1,35 @@
+ import numpy as np
+ from functools import partial
+
+ import pandas as pd
+
+
+ def predict_memory_states(my_collection, group):
+     states = my_collection.states(*group.name)
+     group['stability'] = float(states[0])
+     group['difficulty'] = float(states[1])
+     group['count'] = len(group)
+     return pd.DataFrame({
+         'r_history': [group.name[1]],
+         't_history': [group.name[0]],
+         'stability': [round(float(states[0]), 2)],
+         'difficulty': [round(float(states[1]), 2)],
+         'count': [len(group)]
+     })
+
+
+ def get_my_memory_states(proj_dir, dataset, my_collection):
+     prediction = dataset.groupby(by=['t_history', 'r_history']).progress_apply(
+         partial(predict_memory_states, my_collection))
+     prediction.reset_index(drop=True, inplace=True)
+     prediction.sort_values(by=['r_history'], inplace=True)
+     prediction.to_csv(proj_dir / "prediction.tsv", sep='\t', index=None)
+     print("prediction.tsv saved.")
+     prediction['difficulty'] = prediction['difficulty'].map(lambda x: int(round(x)))
+     difficulty_distribution = prediction.groupby(by=['difficulty'])['count'].sum() / prediction['count'].sum()
+     print(difficulty_distribution)
+     difficulty_distribution_padding = np.zeros(10)
+     for i in range(10):
+         if i + 1 in difficulty_distribution.index:
+             difficulty_distribution_padding[i] = difficulty_distribution.loc[i + 1]
+     return difficulty_distribution_padding, difficulty_distribution
model.py ADDED
@@ -0,0 +1,93 @@
+ import numpy as np
+ import torch
+ from torch import nn
+
+ init_w = [1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.02, 0.8, 2, -0.2, 0.5, 1]
+
+
+ class FSRS(nn.Module):
+     def __init__(self, w):
+         super(FSRS, self).__init__()
+         self.w = nn.Parameter(torch.FloatTensor(w))
+         self.zero = torch.FloatTensor([0.0])
+
+     def forward(self, x, s, d):
+         '''
+         :param x: [review interval, review response]
+         :param s: stability
+         :param d: difficulty
+         :return:
+         '''
+         if torch.equal(s, self.zero):
+             # first learn, init memory states
+             new_s = self.w[0] + self.w[1] * (x[1] - 1)
+             new_d = self.w[2] + self.w[3] * (x[1] - 3)
+             new_d = new_d.clamp(1, 10)
+         else:
+             r = torch.exp(np.log(0.9) * x[0] / s)
+             new_d = d + self.w[4] * (x[1] - 3)
+             new_d = self.mean_reversion(self.w[2], new_d)
+             new_d = new_d.clamp(1, 10)
+             # recall
+             if x[1] > 1:
+                 new_s = s * (1 + torch.exp(self.w[6]) *
+                              (11 - new_d) *
+                              torch.pow(s, self.w[7]) *
+                              (torch.exp((1 - r) * self.w[8]) - 1))
+             # forget
+             else:
+                 new_s = self.w[9] * torch.pow(new_d, self.w[10]) * torch.pow(
+                     s, self.w[11]) * torch.exp((1 - r) * self.w[12])
+         return new_s, new_d
+
+     def loss(self, s, t, r):
+         return - (r * np.log(0.9) * t / s + (1 - r) * torch.log(1 - torch.exp(np.log(0.9) * t / s)))
+
+     def mean_reversion(self, init, current):
+         return self.w[5] * init + (1 - self.w[5]) * current
+
+
+ class WeightClipper(object):
+     def __init__(self, frequency=1):
+         self.frequency = frequency
+
+     def __call__(self, module):
+         if hasattr(module, 'w'):
+             w = module.w.data
+             w[0] = w[0].clamp(0.1, 10)  # initStability
+             w[1] = w[1].clamp(0.1, 5)  # initStabilityRatingFactor
+             w[2] = w[2].clamp(1, 10)  # initDifficulty
+             w[3] = w[3].clamp(-5, -0.1)  # initDifficultyRatingFactor
+             w[4] = w[4].clamp(-5, -0.1)  # updateDifficultyRatingFactor
+             w[5] = w[5].clamp(0, 0.5)  # difficultyMeanReversionFactor
+             w[6] = w[6].clamp(0, 2)  # recallFactor
+             w[7] = w[7].clamp(-0.2, -0.01)  # recallStabilityDecay
+             w[8] = w[8].clamp(0.01, 1.5)  # recallRetrievabilityFactor
+             w[9] = w[9].clamp(0.5, 5)  # forgetFactor
+             w[10] = w[10].clamp(-2, -0.01)  # forgetDifficultyDecay
+             w[11] = w[11].clamp(0.01, 0.9)  # forgetStabilityDecay
+             w[12] = w[12].clamp(0.01, 2)  # forgetRetrievabilityFactor
+             module.w.data = w
+
+
+ def lineToTensor(line):
+     ivl = line[0].split(',')
+     response = line[1].split(',')
+     tensor = torch.zeros(len(response), 2)
+     for li, r in enumerate(response):
+         tensor[li][0] = int(ivl[li])
+         tensor[li][1] = int(r)
+     return tensor
+
+
+ class Collection:
+     def __init__(self, w):
+         self.model = FSRS(w)
+
+     def states(self, t_history, r_history):
+         with torch.no_grad():
+             line_tensor = lineToTensor(list(zip([t_history], [r_history]))[0])
+             output_t = [(self.model.zero, self.model.zero)]
+             for input_t in line_tensor:
+                 output_t.append(self.model(input_t, *output_t[-1]))
+             return output_t[-1]
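
A minimal usage sketch (not part of the commit) of how `Collection.states` is driven by the rest of the repo: both histories are comma-separated strings, and the returned pair is the predicted stability and difficulty after the last review. For example, a card first rated good (3), then reviewed 5 days later and rated good again:

    from model import Collection, init_w

    col = Collection(init_w)
    stability, difficulty = col.states("0,5", "3,3")  # t_history, r_history
    print(float(stability), float(difficulty))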
plot.py ADDED
@@ -0,0 +1,92 @@
+ from tqdm.auto import trange
+ import gradio as gr
+ import pandas as pd
+ import numpy as np
+ import plotly.express as px
+
+
+ def make_plot(proj_dir, type_sequence, w, difficulty_distribution_padding, progress=gr.Progress(track_tqdm=True)):
+     base = 1.01
+     index_len = 800
+     index_offset = 150
+     d_range = 10
+     d_offset = 1
+     r_repetitions = 1
+     f_repetitions = 2.3
+     max_repetitions = 200000
+
+     type_block = dict()
+     type_count = dict()
+     last_t = type_sequence[0]
+     type_block[last_t] = 1
+     type_count[last_t] = 1
+     for t in type_sequence[1:]:
+         type_count[t] = type_count.setdefault(t, 0) + 1
+         if t != last_t:
+             type_block[t] = type_block.setdefault(t, 0) + 1
+         last_t = t
+     if 2 in type_count and 2 in type_block:
+         f_repetitions = round(type_count[2] / type_block[2] + 1, 1)
+
+     def stability2index(stability):
+         return int(round(np.log(stability) / np.log(base)) + index_offset)
+
+     def init_stability(d):
+         return max(((d - w[2]) / w[3] + 2) * w[1] + w[0], np.power(base, -index_offset))
+
+     def cal_next_recall_stability(s, r, d, response):
+         if response == 1:
+             return s * (1 + np.exp(w[6]) * (11 - d) * np.power(s, w[7]) * (np.exp((1 - r) * w[8]) - 1))
+         else:
+             return w[9] * np.power(d, w[10]) * np.power(s, w[11]) * np.exp((1 - r) * w[12])
+
+     stability_list = np.array([np.power(base, i - index_offset) for i in range(index_len)])
+     print(f"terminal stability: {stability_list.max(): .2f}")
+     df = pd.DataFrame(columns=["retention", "difficulty", "repetitions"])
+
+     for percentage in trange(96, 70, -2, desc='Repetition vs Retention plot'):
+         recall = percentage / 100
+         repetitions_list = np.zeros((d_range, index_len))
+         repetitions_list[:, :-1] = max_repetitions
+         for d in range(d_range, 0, -1):
+             s0 = init_stability(d)
+             s0_index = stability2index(s0)
+             diff = max_repetitions
+             while diff > 0.1:
+                 s0_repetitions = repetitions_list[d - 1][s0_index]
+                 for s_index in range(index_len - 2, -1, -1):
+                     stability = stability_list[s_index]
+                     interval = max(1, round(stability * np.log(recall) / np.log(0.9)))
+                     p_recall = np.power(0.9, interval / stability)
+                     recall_s = cal_next_recall_stability(stability, p_recall, d, 1)
+                     forget_d = min(d + d_offset, 10)
+                     forget_s = cal_next_recall_stability(stability, p_recall, forget_d, 0)
+                     recall_s_index = min(stability2index(recall_s), index_len - 1)
+                     forget_s_index = min(max(stability2index(forget_s), 0), index_len - 1)
+                     recall_repetitions = repetitions_list[d - 1][recall_s_index] + r_repetitions
+                     forget_repetitions = repetitions_list[forget_d - 1][forget_s_index] + f_repetitions
+                     exp_repetitions = p_recall * recall_repetitions + (1.0 - p_recall) * forget_repetitions
+                     if exp_repetitions < repetitions_list[d - 1][s_index]:
+                         repetitions_list[d - 1][s_index] = exp_repetitions
+                 diff = s0_repetitions - repetitions_list[d - 1][s0_index]
+             df.loc[0 if pd.isnull(df.index.max()) else df.index.max() + 1] = [recall, d, s0_repetitions]
+
+     df.sort_values(by=["difficulty", "retention"], inplace=True)
+     df.to_csv(proj_dir / "expected_repetitions.csv", index=False)
+     print("expected_repetitions.csv saved.")
+
+     optimal_retention_list = np.zeros(10)
+     df2 = pd.DataFrame()
+     for d in range(1, d_range + 1):
+         retention = df[df["difficulty"] == d]["retention"]
+         repetitions = df[df["difficulty"] == d]["repetitions"]
+         optimal_retention = retention.iat[repetitions.argmin()]
+         optimal_retention_list[d - 1] = optimal_retention
+         df2 = df2.append(
+             pd.DataFrame({'retention': retention, 'expected repetitions': repetitions, 'd': d, 'r': optimal_retention}))
+
+     fig = px.line(df2, x="retention", y="expected repetitions", color='d', log_y=True)
+
+     print(f"\n-----suggested retention: {np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}-----")
+     suggested_retention_markdown = f"""# Suggested Retention: `{np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}`"""
+     return fig, suggested_retention_markdown
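
For reference, the suggested retention printed at the end of `make_plot` is just the difficulty-weighted average of the per-difficulty optima via `np.inner`. A tiny sketch with hypothetical numbers (not from the commit):

    import numpy as np

    difficulty_distribution_padding = np.array([0, 0.1, 0.3, 0.4, 0.2, 0, 0, 0, 0, 0])  # hypothetical share of cards at difficulty 1-10
    optimal_retention_list = np.array([0.96, 0.94, 0.90, 0.86, 0.84, 0.8, 0.8, 0.8, 0.8, 0.8])  # hypothetical per-difficulty optima
    print(f"{np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}")  # -> 0.88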
projects/.gitkeep ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ matplotlib==3.4.3
+ numpy==1.23.3
+ pandas==1.3.2
+ scikit_learn==1.1.2
+ torch==1.9.0
+ tqdm==4.64.1
+ plotly==5.13.0
utilities.py ADDED
@@ -0,0 +1,296 @@
+ from functools import partial
+ import datetime
+ from zipfile import ZipFile
+
+ import sqlite3
+ import time
+
+ import gradio as gr
+ from tqdm.auto import tqdm
+ import pandas as pd
+ import numpy as np
+ import os
+ from datetime import timedelta, datetime
+ from pathlib import Path
+
+ import torch
+ from sklearn.utils import shuffle
+
+ from model import Collection, init_w, FSRS, WeightClipper, lineToTensor
+
+
+ # Extract the collection file or deck file to get the .anki21 database.
+
+
+ def extract(file, prefix):
+     proj_dir = Path(f'projects/{prefix}_{file.orig_name.replace(".", "_").replace("@", "_")}')
+     with ZipFile(file, 'r') as zip_ref:
+         zip_ref.extractall(proj_dir)
+         print(f"Extracted {file.orig_name} successfully!")
+     return proj_dir
+
+
+ def create_time_series_features(revlog_start_date, timezone, next_day_starts_at, proj_dir,
+                                 progress=gr.Progress(track_tqdm=True)):
+     if os.path.isfile(proj_dir / "collection.anki21b"):
+         os.remove(proj_dir / "collection.anki21b")
+         raise Exception(
+             "Please export the file with `Support older Anki versions` checked if you use the latest version of Anki.")
+     elif os.path.isfile(proj_dir / "collection.anki21"):
+         con = sqlite3.connect(proj_dir / "collection.anki21")
+     elif os.path.isfile(proj_dir / "collection.anki2"):
+         con = sqlite3.connect(proj_dir / "collection.anki2")
+     else:
+         raise Exception("Collection does not exist!")
+     cur = con.cursor()
+     res = cur.execute("SELECT * FROM revlog")
+     revlog = res.fetchall()
+
+     df = pd.DataFrame(revlog)
+     df.columns = ['id', 'cid', 'usn', 'r', 'ivl',
+                   'last_lvl', 'factor', 'time', 'type']
+     df = df[(df['cid'] <= time.time() * 1000) &
+             (df['id'] <= time.time() * 1000) &
+             (df['r'] > 0) &
+             (df['id'] >= time.mktime(datetime.strptime(revlog_start_date, "%Y-%m-%d").timetuple()) * 1000)].copy()
+     df['create_date'] = pd.to_datetime(df['cid'] // 1000, unit='s')
+     df['create_date'] = df['create_date'].dt.tz_localize(
+         'UTC').dt.tz_convert(timezone)
+     df['review_date'] = pd.to_datetime(df['id'] // 1000, unit='s')
+     df['review_date'] = df['review_date'].dt.tz_localize(
+         'UTC').dt.tz_convert(timezone)
+     df.drop(df[df['review_date'].dt.year < 2006].index, inplace=True)
+     df.sort_values(by=['cid', 'id'], inplace=True, ignore_index=True)
+     type_sequence = np.array(df['type'])
+     df.to_csv(proj_dir / "revlog.csv", index=False)
+     print("revlog.csv saved.")
+     df = df[(df['type'] == 0) | (df['type'] == 1)].copy()
+     df['real_days'] = df['review_date'] - timedelta(hours=next_day_starts_at)
+     df['real_days'] = pd.DatetimeIndex(df['real_days'].dt.floor('D')).to_julian_date()
+     df.drop_duplicates(['cid', 'real_days'], keep='first', inplace=True)
+     df['delta_t'] = df.real_days.diff()
+     df.dropna(inplace=True)
+     df['delta_t'] = df['delta_t'].astype(dtype=int)
+     df['i'] = 1
+     df['r_history'] = ""
+     df['t_history'] = ""
+     col_idx = {key: i for i, key in enumerate(df.columns)}
+
+     # code from https://github.com/L-M-Sherlock/anki_revlog_analysis/blob/main/revlog_analysis.py
+     def get_feature(x):
+         for idx, log in enumerate(x.itertuples()):
+             if idx == 0:
+                 x.iloc[idx, col_idx['delta_t']] = 0
+             if idx == x.shape[0] - 1:
+                 break
+             x.iloc[idx + 1, col_idx['i']] = x.iloc[idx, col_idx['i']] + 1
+             x.iloc[idx + 1, col_idx[
+                 't_history']] = f"{x.iloc[idx, col_idx['t_history']]},{x.iloc[idx, col_idx['delta_t']]}"
+             x.iloc[idx + 1, col_idx['r_history']] = f"{x.iloc[idx, col_idx['r_history']]},{x.iloc[idx, col_idx['r']]}"
+         return x
+
+     tqdm.pandas(desc='Saving Trainset')
+     df = df.groupby('cid', as_index=False).progress_apply(get_feature)
+     df["t_history"] = df["t_history"].map(lambda x: x[1:] if len(x) > 1 else x)
+     df["r_history"] = df["r_history"].map(lambda x: x[1:] if len(x) > 1 else x)
+     df.to_csv(proj_dir / 'revlog_history.tsv', sep="\t", index=False)
+     print("Trainset saved.")
+
+     def cal_retention(group: pd.DataFrame) -> pd.DataFrame:
+         group['retention'] = round(group['r'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x]).mean(), 4)
+         group['total_cnt'] = group.shape[0]
+         return group
+
+     tqdm.pandas(desc='Calculating Retention')
+     df = df.groupby(by=['r_history', 'delta_t']).progress_apply(cal_retention)
+     print("Retention calculated.")
+     df = df.drop(columns=['id', 'cid', 'usn', 'ivl', 'last_lvl', 'factor', 'time', 'type', 'create_date', 'review_date',
+                           'real_days', 'r', 't_history'])
+     df.drop_duplicates(inplace=True)
+     df = df[(df['retention'] < 1) & (df['retention'] > 0)]
+
+     def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
+         if group['i'].values[0] > 1:
+             r_ivl_cnt = sum(group['delta_t'] * group['retention'].map(np.log) * pow(group['total_cnt'], 2))
+             ivl_ivl_cnt = sum(group['delta_t'].map(lambda x: x ** 2) * pow(group['total_cnt'], 2))
+             group['stability'] = round(np.log(0.9) / (r_ivl_cnt / ivl_ivl_cnt), 1)
+         else:
+             group['stability'] = 0.0
+         group['group_cnt'] = sum(group['total_cnt'])
+         group['avg_retention'] = round(
+             sum(group['retention'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 3)
+         group['avg_interval'] = round(
+             sum(group['delta_t'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 1)
+         del group['total_cnt']
+         del group['retention']
+         del group['delta_t']
+         return group
+
+     tqdm.pandas(desc='Calculating Stability')
+     df = df.groupby(by=['r_history']).progress_apply(cal_stability)
+     print("Stability calculated.")
+     df.reset_index(drop=True, inplace=True)
+     df.drop_duplicates(inplace=True)
+     df.sort_values(by=['r_history'], inplace=True, ignore_index=True)
+
+     df_out = pd.DataFrame()
+     if df.shape[0] > 0:
+         for idx in tqdm(df.index):
+             item = df.loc[idx]
+             index = df[(df['i'] == item['i'] + 1) & (df['r_history'].str.startswith(item['r_history']))].index
+             df.loc[index, 'last_stability'] = item['stability']
+         df['factor'] = round(df['stability'] / df['last_stability'], 2)
+         df = df[(df['i'] >= 2) & (df['group_cnt'] >= 100)]
+         df['last_recall'] = df['r_history'].map(lambda x: x[-1])
+         df = df[df.groupby(['i', 'r_history'])['group_cnt'].transform(max) == df['group_cnt']]
+         df.to_csv(proj_dir / 'stability_for_analysis.tsv', sep='\t', index=None)
+         print("1:again, 2:hard, 3:good, 4:easy\n")
+         print(df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][
+                   ['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(
+             index=False))
+         print("Analysis saved!")
+
+         df_out = df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][
+             ['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']]
+     return type_sequence, df_out
+
+
+ def train_model(proj_dir, progress=gr.Progress(track_tqdm=True)):
+     model = FSRS(init_w)
+
+     clipper = WeightClipper()
+     optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
+
+     dataset = pd.read_csv(proj_dir / "revlog_history.tsv", sep='\t', index_col=None,
+                           dtype={'r_history': str, 't_history': str})
+     dataset = dataset[(dataset['i'] > 1) & (dataset['delta_t'] > 0) & (dataset['t_history'].str.count(',0') == 0)]
+
+     tqdm.pandas(desc='Tensorizing Line')
+     dataset['tensor'] = dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]),
+                                                axis=1)
+     print("Tensorized!")
+
+     pre_train_set = dataset[dataset['i'] == 2]
+     # pretrain
+     epoch_len = len(pre_train_set)
+     n_epoch = 1
+     pbar = tqdm(desc="Pre-training", colour="red", total=epoch_len * n_epoch)
+
+     for k in range(n_epoch):
+         for i, (_, row) in enumerate(shuffle(pre_train_set, random_state=2022 + k).iterrows()):
+             model.train()
+             optimizer.zero_grad()
+             output_t = [(model.zero, model.zero)]
+             for input_t in row['tensor']:
+                 output_t.append(model(input_t, *output_t[-1]))
+             loss = model.loss(output_t[-1][0], row['delta_t'],
+                               {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])
+             if np.isnan(loss.data.item()):
+                 # Exception Case
+                 print(row, output_t)
+                 raise Exception('error case')
+             loss.backward()
+             optimizer.step()
+             model.apply(clipper)
+             pbar.update()
+     pbar.close()
+     for name, param in model.named_parameters():
+         print(f"{name}: {list(map(lambda x: round(float(x), 4), param))}")
+
+     train_set = dataset[dataset['i'] > 2]
+     epoch_len = len(train_set)
+     n_epoch = 1
+     print_len = max(epoch_len * n_epoch // 10, 1)
+     pbar = tqdm(desc="Training", total=epoch_len * n_epoch)
+
+     for k in range(n_epoch):
+         for i, (_, row) in enumerate(shuffle(train_set, random_state=2022 + k).iterrows()):
+             model.train()
+             optimizer.zero_grad()
+             output_t = [(model.zero, model.zero)]
+             for input_t in row['tensor']:
+                 output_t.append(model(input_t, *output_t[-1]))
+             loss = model.loss(output_t[-1][0], row['delta_t'],
+                               {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])
+             if np.isnan(loss.data.item()):
+                 # Exception Case
+                 print(row, output_t)
+                 raise Exception('error case')
+             loss.backward()
+             for param in model.parameters():
+                 param.grad[:2] = torch.zeros(2)
+             optimizer.step()
+             model.apply(clipper)
+             pbar.update()
+
+             if (k * epoch_len + i) % print_len == 0:
+                 print(f"iteration: {k * epoch_len + i + 1}")
+                 for name, param in model.named_parameters():
+                     print(f"{name}: {list(map(lambda x: round(float(x), 4), param))}")
+     pbar.close()
+
+     w = list(map(lambda x: round(float(x), 4), dict(model.named_parameters())['w'].data))
+
+     print("\nTraining finished!")
+     return w, dataset
+
+
+ def process_personalized_collection(requestRetention, w):
+     my_collection = Collection(w)
+     rating_dict = {1: "again", 2: "hard", 3: "good", 4: "easy"}
+     rating_markdown = []
+     for first_rating in (1, 2, 3, 4):
+         rating_markdown.append(f'## First Rating: {first_rating} ({rating_dict[first_rating]})')
+         t_history = "0"
+         d_history = "0"
+         r_history = f"{first_rating}"  # the first rating of the new card
+         # print("stability, difficulty, lapses")
+         for i in range(10):
+             states = my_collection.states(t_history, r_history)
+             # print('{0:9.2f} {1:11.2f} {2:7.0f}'.format(
+             #     *list(map(lambda x: round(float(x), 4), states))))
+             next_t = max(round(float(np.log(requestRetention) / np.log(0.9) * states[0])), 1)
+             difficulty = round(float(states[1]), 1)
+             t_history += f',{int(next_t)}'
+             d_history += f',{difficulty}'
+             r_history += f",3"
+         rating_markdown.append(f"*rating history*: {r_history}")
+         rating_markdown.append(f"*interval history*: {t_history}")
+         rating_markdown.append(f"*difficulty history*: {d_history}\n")
+     rating_markdown = '\n\n'.join(rating_markdown)
+     return my_collection, rating_markdown
+
+
+ def log_loss(my_collection, row):
+     states = my_collection.states(row['t_history'], row['r_history'])
+     row['log_loss'] = float(my_collection.model.loss(states[0], row['delta_t'], {1: 0, 2: 1, 3: 1, 4: 1}[row['r']]))
+     return row
+
+
+ def my_loss(dataset, w):
+     my_collection = Collection(init_w)
+     tqdm.pandas(desc='Calculating Loss before Training')
+     loss_before = dataset.progress_apply(partial(log_loss, my_collection), axis=1)['log_loss'].mean()
+     print(f"Loss before training: {loss_before:.4f}")
+     my_collection = Collection(w)
+     tqdm.pandas(desc='Calculating Loss After Training')
+     loss_after = dataset.progress_apply(partial(log_loss, my_collection), axis=1)['log_loss'].mean()
+     print(f"Loss after training: {loss_after:.4f}")
+     return f"""
+ *Loss before training*: {loss_before:.4f}
+
+ *Loss after training*: {loss_after:.4f}
+ """
+
+
+ def cleanup(proj_dir: Path, files):
+     """
+     Delete all files in proj_dir whose filenames are not in files.
+     :param proj_dir:
+     :param files:
+     :return:
+     """
+     for file in proj_dir.glob('*'):
+         if file.name not in files:
+             os.remove(file)
+
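
A note on the `cal_stability` helper in utilities.py above (reasoning only, nothing added to the commit): it fits the FSRS forgetting curve R = 0.9^(t/S). Taking logs gives ln R = ln(0.9) · t / S, so a least-squares fit of ln(retention) against delta_t through the origin, weighted by total_cnt², has slope r_ivl_cnt / ivl_ivl_cnt, and the group's stability is recovered as S = ln(0.9) / slope, which is the `np.log(0.9) / (r_ivl_cnt / ivl_ivl_cnt)` expression in the code. For instance, a single group with retention 0.81 at delta_t = 20 days would fit S = 10 days, since 0.81 = 0.9² makes the slope ln(0.9)/10.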