Files changed (7)
  1. app.py +39 -36
  2. fsrs4anki_optimizer.ipynb +0 -0
  3. memory_states.py +0 -35
  4. model.py +0 -110
  5. plot.py +0 -99
  6. requirements.txt +1 -7
  7. utilities.py +0 -284
app.py CHANGED
@@ -1,14 +1,11 @@
  import gradio as gr
  import pytz
- 
  from datetime import datetime
- 
- from utilities import extract, create_time_series_features, train_model, process_personalized_collection, my_loss, \
-     cleanup
  from markdown import instructions_markdown, faq_markdown
- from memory_states import get_my_memory_states
- from plot import make_plot
- 
  
  def get_w_markdown(w):
      return f"""
@@ -20,30 +17,38 @@ def get_w_markdown(w):
      Check out the Analysis tab for more detailed information."""
  
  
- def anki_optimizer(file, timezone, next_day_starts_at, revlog_start_date, requestRetention, fast_mode,
                     progress=gr.Progress(track_tqdm=True)):
      now = datetime.now()
      files = ['prediction.tsv', 'revlog.csv', 'revlog_history.tsv', 'stability_for_analysis.tsv',
-              'expected_repetitions.csv']
      prefix = now.strftime(f'%Y_%m_%d_%H_%M_%S')
- 
-     proj_dir = extract(file, prefix)
- 
-     type_sequence, time_sequence, df_out = create_time_series_features(revlog_start_date, timezone, next_day_starts_at, proj_dir)
-     w, dataset = train_model(proj_dir)
-     w_markdown = get_w_markdown(w)
-     cleanup(proj_dir, files)
-     if fast_mode:
-         files_out = [proj_dir / file for file in files if (proj_dir / file).exists()]
-         return w_markdown, None, None, "", files_out
- 
-     my_collection, rating_markdown = process_personalized_collection(requestRetention, w)
-     difficulty_distribution_padding, difficulty_distribution = get_my_memory_states(proj_dir, dataset, my_collection)
-     fig, suggested_retention_markdown = make_plot(proj_dir, type_sequence, time_sequence, w, difficulty_distribution_padding)
-     loss_markdown = my_loss(dataset, w)
-     difficulty_distribution = difficulty_distribution.to_string().replace("\n", "\n\n")
-     markdown_out = f"""
-     {suggested_retention_markdown}
  
      # Loss Information
      {loss_markdown}
@@ -54,12 +59,13 @@ def anki_optimizer(file, timezone, next_day_starts_at, revlog_start_date, reques
      # Ratings
      {rating_markdown}
      """
-     files_out = [proj_dir / file for file in files if (proj_dir / file).exists()]
-     return w_markdown, df_out, fig, markdown_out, files_out
  
  
  description = """
- # FSRS4Anki Optimizer App - v3.14.7
  Based on the [tutorial](https://medium.com/@JarrettYe/how-to-use-the-next-generation-spaced-repetition-algorithm-fsrs-on-anki-5a591ca562e2)
  of [Jarrett Ye](https://github.com/L-M-Sherlock). This application can give you personalized anki parameters without having to code.
  
@@ -74,7 +80,6 @@ with gr.Blocks() as demo:
      with gr.Row():
          with gr.Column():
              file = gr.File(label='Review Logs (Step 1)')
-             fast_mode_in = gr.Checkbox(value=False, label="Fast Mode (Will just return the optimized weights)")
          with gr.Column():
              next_day_starts_at = gr.Number(value=4,
                                             label="Next Day Starts at (Step 2)",
@@ -95,15 +100,13 @@
      with gr.Row():
          markdown_output = gr.Markdown()
          with gr.Column():
-             df_output = gr.DataFrame()
              plot_output = gr.Plot()
              files_output = gr.Files(label="Analysis Files")
      with gr.Tab("FAQ"):
          gr.Markdown(faq_markdown)
  
      btn_plot.click(anki_optimizer,
-                    inputs=[file, timezone, next_day_starts_at, revlog_start_date, requestRetention, fast_mode_in],
-                    outputs=[w_output, df_output, plot_output, markdown_output, files_output])
  
- if __name__ == '__main__':
-     demo.queue().launch(show_error=True)

  import gradio as gr
  import pytz
+ import os
  from datetime import datetime
  
  from markdown import instructions_markdown, faq_markdown
+ from fsrs4anki_optimizer import Optimizer
+ from pathlib import Path
+ from utilities import cleanup
  
  def get_w_markdown(w):
      return f"""
  
      Check out the Analysis tab for more detailed information."""
  
  
+ def anki_optimizer(file: gr.File, timezone, next_day_starts_at, revlog_start_date, requestRetention,
                     progress=gr.Progress(track_tqdm=True)):
      now = datetime.now()
      files = ['prediction.tsv', 'revlog.csv', 'revlog_history.tsv', 'stability_for_analysis.tsv',
+              'expected_time.csv', 'evaluation.tsv']
      prefix = now.strftime(f'%Y_%m_%d_%H_%M_%S')
+     suffix = file.name.split('/')[-1].replace(".", "_").replace("@", "_")
+     proj_dir = Path(f'projects/{prefix}/{suffix}')
+     proj_dir.mkdir(parents=True, exist_ok=True)
+     print(proj_dir)
+     os.chdir(proj_dir)
+     proj_dir = Path('.')
+     optimizer = Optimizer()
+     optimizer.anki_extract(file.name)
+     analysis_markdown = optimizer.create_time_series(timezone, revlog_start_date, next_day_starts_at).replace("\n", "\n\n")
+     optimizer.define_model()
+     optimizer.train()
+     w_markdown = get_w_markdown(optimizer.w)
+     optimizer.predict_memory_states()
+     difficulty_distribution = optimizer.difficulty_distribution.to_string().replace("\n", "\n\n")
+     plot_output = optimizer.find_optimal_retention()[0]
+     suggested_retention_markdown = f"""# Suggested Retention: `{optimizer.optimal_retention:.2f}`"""
+     rating_markdown = optimizer.preview(requestRetention).replace("\n", "\n\n")
+     loss_before, loss_after = optimizer.evaluate()
+     loss_markdown = f"""
+     **Loss before training**: {loss_before}
+ 
+     **Loss after training**: {loss_after}
+     """
+     # optimizer.calibration_graph()
+     # optimizer.compare_with_sm2()
+     markdown_out = f"""{suggested_retention_markdown}
  
      # Loss Information
      {loss_markdown}
  
      # Ratings
      {rating_markdown}
      """
+     files_out = [file for file in files if (proj_dir / file).exists()]
+     cleanup(proj_dir, files)
+     return w_markdown, markdown_out, plot_output, files_out
  
  
  description = """
+ # FSRS4Anki Optimizer App - v3.24.1
  Based on the [tutorial](https://medium.com/@JarrettYe/how-to-use-the-next-generation-spaced-repetition-algorithm-fsrs-on-anki-5a591ca562e2)
  of [Jarrett Ye](https://github.com/L-M-Sherlock). This application can give you personalized anki parameters without having to code.
  
      with gr.Row():
          with gr.Column():
              file = gr.File(label='Review Logs (Step 1)')
          with gr.Column():
              next_day_starts_at = gr.Number(value=4,
                                             label="Next Day Starts at (Step 2)",
  
      with gr.Row():
          markdown_output = gr.Markdown()
          with gr.Column():
              plot_output = gr.Plot()
              files_output = gr.Files(label="Analysis Files")
      with gr.Tab("FAQ"):
          gr.Markdown(faq_markdown)
  
      btn_plot.click(anki_optimizer,
+                    inputs=[file, timezone, next_day_starts_at, revlog_start_date, requestRetention],
+                    outputs=[w_output, markdown_output, plot_output, files_output])
  
+ demo.queue().launch(show_error=True)
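
The rewrite drops the in-repo training pipeline (utilities.py, model.py, memory_states.py, plot.py) and delegates the work to the published fsrs4anki_optimizer package; app.py now only wires the Gradio UI to the Optimizer calls shown above. Below is a minimal sketch of the same flow outside Gradio, assuming the package API exactly as it is invoked in the new app.py. The deck path, timezone, dates, and retention are placeholder values; the app chdirs into a per-run project directory first, apparently because the optimizer reads and writes its intermediate files in the current working directory.

    import os
    from pathlib import Path

    from fsrs4anki_optimizer import Optimizer

    # Placeholder inputs -- substitute your own export and preferences.
    deck_path = Path("collection.colpkg").resolve()
    proj_dir = Path("projects/example")
    proj_dir.mkdir(parents=True, exist_ok=True)
    os.chdir(proj_dir)  # revlog.csv, prediction.tsv, etc. end up here

    optimizer = Optimizer()
    optimizer.anki_extract(str(deck_path))                          # unpack the Anki export
    optimizer.create_time_series("Europe/London", "2006-10-05", 4)  # timezone, revlog start date, next-day start hour
    optimizer.define_model()
    optimizer.train()                                               # fit the FSRS weights
    print(optimizer.w)                                              # personalized parameters for FSRS4Anki

    optimizer.predict_memory_states()                               # also fills optimizer.difficulty_distribution
    fig = optimizer.find_optimal_retention()[0]                     # expected-time-vs-retention plot
    print(optimizer.optimal_retention)                              # suggested requestRetention
    print(optimizer.preview(0.9))                                   # scheduling preview at a chosen retention
    loss_before, loss_after = optimizer.evaluate()                  # log loss before/after training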
 
fsrs4anki_optimizer.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
memory_states.py DELETED
@@ -1,35 +0,0 @@
- import numpy as np
- from functools import partial
- 
- import pandas as pd
- 
- 
- def predict_memory_states(my_collection, group):
-     states = my_collection.states(*group.name)
-     group['stability'] = float(states[0])
-     group['difficulty'] = float(states[1])
-     group['count'] = len(group)
-     return pd.DataFrame({
-         'r_history': [group.name[1]],
-         't_history': [group.name[0]],
-         'stability': [round(float(states[0]), 2)],
-         'difficulty': [round(float(states[1]), 2)],
-         'count': [len(group)]
-     })
- 
- 
- def get_my_memory_states(proj_dir, dataset, my_collection):
-     prediction = dataset.groupby(by=['t_history', 'r_history']).progress_apply(
-         partial(predict_memory_states, my_collection))
-     prediction.reset_index(drop=True, inplace=True)
-     prediction.sort_values(by=['r_history'], inplace=True)
-     prediction.to_csv(proj_dir / "prediction.tsv", sep='\t', index=None)
-     # print("prediction.tsv saved.")
-     prediction['difficulty'] = prediction['difficulty'].map(lambda x: int(round(x)))
-     difficulty_distribution = prediction.groupby(by=['difficulty'])['count'].sum() / prediction['count'].sum()
-     # print(difficulty_distribution)
-     difficulty_distribution_padding = np.zeros(10)
-     for i in range(10):
-         if i + 1 in difficulty_distribution.index:
-             difficulty_distribution_padding[i] = difficulty_distribution.loc[i + 1]
-     return difficulty_distribution_padding, difficulty_distribution
  
model.py DELETED
@@ -1,110 +0,0 @@
- import numpy as np
- import torch
- from torch import nn
- 
- init_w = [1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.12, 0.8, 2, -0.2, 0.2, 1]
- '''
- w[0]: initial_stability_for_again_answer
- w[1]: initial_stability_step_per_rating
- w[2]: initial_difficulty_for_good_answer
- w[3]: initial_difficulty_step_per_rating
- w[4]: next_difficulty_step_per_rating
- w[5]: next_difficulty_reversion_to_mean_speed (used to avoid ease hell)
- w[6]: next_stability_factor_after_success
- w[7]: next_stability_stabilization_decay_after_success
- w[8]: next_stability_retrievability_gain_after_success
- w[9]: next_stability_factor_after_failure
- w[10]: next_stability_difficulty_decay_after_success
- w[11]: next_stability_stability_gain_after_failure
- w[12]: next_stability_retrievability_gain_after_failure
- For more details about the parameters, please see:
- https://github.com/open-spaced-repetition/fsrs4anki/wiki/Free-Spaced-Repetition-Scheduler
- '''
- 
- 
- class FSRS(nn.Module):
-     def __init__(self, w):
-         super(FSRS, self).__init__()
-         self.w = nn.Parameter(torch.FloatTensor(w))
-         self.zero = torch.FloatTensor([0.0])
- 
-     def forward(self, x, s, d):
-         '''
-         :param x: [review interval, review response]
-         :param s: stability
-         :param d: difficulty
-         :return:
-         '''
-         if torch.equal(s, self.zero):
-             # first learn, init memory states
-             new_s = self.w[0] + self.w[1] * (x[1] - 1)
-             new_d = self.w[2] + self.w[3] * (x[1] - 3)
-             new_d = new_d.clamp(1, 10)
-         else:
-             r = torch.exp(np.log(0.9) * x[0] / s)
-             new_d = d + self.w[4] * (x[1] - 3)
-             new_d = self.mean_reversion(self.w[2], new_d)
-             new_d = new_d.clamp(1, 10)
-             # recall
-             if x[1] > 1:
-                 new_s = s * (1 + torch.exp(self.w[6]) *
-                              (11 - new_d) *
-                              torch.pow(s, self.w[7]) *
-                              (torch.exp((1 - r) * self.w[8]) - 1))
-             # forget
-             else:
-                 new_s = self.w[9] * torch.pow(new_d, self.w[10]) * torch.pow(
-                     s, self.w[11]) * torch.exp((1 - r) * self.w[12])
-         return new_s, new_d
- 
-     def loss(self, s, t, r):
-         return - (r * np.log(0.9) * t / s + (1 - r) * torch.log(1 - torch.exp(np.log(0.9) * t / s)))
- 
-     def mean_reversion(self, init, current):
-         return self.w[5] * init + (1 - self.w[5]) * current
- 
- 
- class WeightClipper(object):
-     def __init__(self, frequency=1):
-         self.frequency = frequency
- 
-     def __call__(self, module):
-         if hasattr(module, 'w'):
-             w = module.w.data
-             w[0] = w[0].clamp(0.1, 10)
-             w[1] = w[1].clamp(0.1, 5)
-             w[2] = w[2].clamp(1, 10)
-             w[3] = w[3].clamp(-5, -0.1)
-             w[4] = w[4].clamp(-5, -0.1)
-             w[5] = w[5].clamp(0, 0.5)
-             w[6] = w[6].clamp(0, 2)
-             w[7] = w[7].clamp(-0.2, -0.01)
-             w[8] = w[8].clamp(0.01, 1.5)
-             w[9] = w[9].clamp(0.5, 5)
-             w[10] = w[10].clamp(-2, -0.01)
-             w[11] = w[11].clamp(0.01, 0.9)
-             w[12] = w[12].clamp(0.01, 2)
-             module.w.data = w
- 
- 
- def lineToTensor(line):
-     ivl = line[0].split(',')
-     response = line[1].split(',')
-     tensor = torch.zeros(len(response), 2)
-     for li, response in enumerate(response):
-         tensor[li][0] = int(ivl[li])
-         tensor[li][1] = int(response)
-     return tensor
- 
- 
- class Collection:
-     def __init__(self, w):
-         self.model = FSRS(w)
- 
-     def states(self, t_history, r_history):
-         with torch.no_grad():
-             line_tensor = lineToTensor(list(zip([t_history], [r_history]))[0])
-             output_t = [(self.model.zero, self.model.zero)]
-             for input_t in line_tensor:
-                 output_t.append(self.model(input_t, *output_t[-1]))
-             return output_t[-1]
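
For reference, the update rules encoded in this deleted model.py (the same FSRS v3 rules the fsrs4anki_optimizer package now trains) can be written out compactly. G is the review grade (1 = again, 2 = hard, 3 = good, 4 = easy), t the elapsed days, S stability, D difficulty, and clamp(x, 1, 10) the clipping used in the code:

    R = 0.9^{t/S}                                                                % retrievability after t days
    S_0 = w_0 + w_1 (G - 1)                                                      % initial stability
    D_0 = \mathrm{clamp}(w_2 + w_3 (G - 3),\, 1,\, 10)                           % initial difficulty
    D'  = \mathrm{clamp}\big(w_5 w_2 + (1 - w_5)(D + w_4 (G - 3)),\, 1,\, 10\big)   % difficulty update, mean-reverting to w_2
    S'_{\mathrm{recall}} = S \big(1 + e^{w_6} (11 - D')\, S^{w_7} (e^{(1 - R) w_8} - 1)\big)   % after success (G > 1)
    S'_{\mathrm{forget}} = w_9\, D'^{\,w_{10}}\, S^{w_{11}}\, e^{(1 - R) w_{12}}               % after a lapse (G = 1)
    \mathcal{L} = -\big(y \ln 0.9^{\Delta t / S} + (1 - y) \ln(1 - 0.9^{\Delta t / S})\big)    % per-review log loss, y = recalled or not
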
plot.py DELETED
@@ -1,99 +0,0 @@
- from tqdm.auto import trange
- import gradio as gr
- import pandas as pd
- import numpy as np
- import plotly.express as px
- 
- 
- def make_plot(proj_dir, type_sequence, time_sequence, w, difficulty_distribution_padding, progress=gr.Progress(track_tqdm=True)):
-     base = 1.01
-     index_len = 793
-     index_offset = 200
-     d_range = 10
-     d_offset = 1
-     r_time = 8
-     f_time = 25
-     max_time = 200000
- 
-     type_block = dict()
-     type_count = dict()
-     type_time = dict()
-     last_t = type_sequence[0]
-     type_block[last_t] = 1
-     type_count[last_t] = 1
-     type_time[last_t] = time_sequence[0]
-     for i, t in enumerate(type_sequence[1:]):
-         type_count[t] = type_count.setdefault(t, 0) + 1
-         type_time[t] = type_time.setdefault(t, 0) + time_sequence[i]
-         if t != last_t:
-             type_block[t] = type_block.setdefault(t, 0) + 1
-         last_t = t
- 
-     r_time = round(type_time[1]/type_count[1]/1000, 1)
- 
-     if 2 in type_count and 2 in type_block:
-         f_time = round(type_time[2]/type_block[2]/1000 + r_time, 1)
- 
-     def stability2index(stability):
-         return int(round(np.log(stability) / np.log(base)) + index_offset)
- 
-     def init_stability(d):
-         return max(((d - w[2]) / w[3] + 2) * w[1] + w[0], np.power(base, -index_offset))
- 
-     def cal_next_recall_stability(s, r, d, response):
-         if response == 1:
-             return s * (1 + np.exp(w[6]) * (11 - d) * np.power(s, w[7]) * (np.exp((1 - r) * w[8]) - 1))
-         else:
-             return w[9] * np.power(d, w[10]) * np.power(s, w[11]) * np.exp((1 - r) * w[12])
- 
-     stability_list = np.array([np.power(base, i - index_offset) for i in range(index_len)])
-     # print(f"terminal stability: {stability_list.max(): .2f}")
-     df = pd.DataFrame(columns=["retention", "difficulty", "time"])
- 
-     for percentage in trange(96, 66, -2, desc='Time vs Retention plot'):
-         recall = percentage / 100
-         time_list = np.zeros((d_range, index_len))
-         time_list[:, :-1] = max_time
-         for d in range(d_range, 0, -1):
-             s0 = init_stability(d)
-             s0_index = stability2index(s0)
-             diff = max_time
-             while diff > 0.1:
-                 s0_time = time_list[d - 1][s0_index]
-                 for s_index in range(index_len - 2, -1, -1):
-                     stability = stability_list[s_index]
-                     interval = max(1, round(stability * np.log(recall) / np.log(0.9)))
-                     p_recall = np.power(0.9, interval / stability)
-                     recall_s = cal_next_recall_stability(stability, p_recall, d, 1)
-                     forget_d = min(d + d_offset, 10)
-                     forget_s = cal_next_recall_stability(stability, p_recall, forget_d, 0)
-                     recall_s_index = min(stability2index(recall_s), index_len - 1)
-                     forget_s_index = min(max(stability2index(forget_s), 0), index_len - 1)
-                     recall_time = time_list[d - 1][recall_s_index] + r_time
-                     forget_time = time_list[forget_d - 1][forget_s_index] + f_time
-                     exp_time = p_recall * recall_time + (1.0 - p_recall) * forget_time
-                     if exp_time < time_list[d - 1][s_index]:
-                         time_list[d - 1][s_index] = exp_time
-                 diff = s0_time - time_list[d - 1][s0_index]
-             df.loc[0 if pd.isnull(df.index.max()) else df.index.max() + 1] = [recall, d, s0_time]
- 
- 
-     df.sort_values(by=["difficulty", "retention"], inplace=True)
-     df.to_csv(proj_dir/"expected_time.csv", index=False)
-     # print("expected_repetitions.csv saved.")
- 
-     optimal_retention_list = np.zeros(10)
-     df2 = pd.DataFrame()
-     for d in range(1, d_range + 1):
-         retention = df[df["difficulty"] == d]["retention"]
-         time = df[df["difficulty"] == d]["time"]
-         optimal_retention = retention.iat[time.argmin()]
-         optimal_retention_list[d - 1] = optimal_retention
-         df2 = df2.append(
-             pd.DataFrame({'retention': retention, 'expected time': time, 'd': d, 'r': optimal_retention}))
- 
-     fig = px.line(df2, x="retention", y="expected time", color='d', log_y=True)
- 
-     # print(f"\n-----suggested retention: {np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}-----")
-     suggested_retention_markdown = f"""# Suggested Retention: `{np.inner(difficulty_distribution_padding, optimal_retention_list):.2f}`"""
-     return fig, suggested_retention_markdown
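
The deleted make_plot above is a small dynamic-programming routine: it estimates the average seconds spent on a successful review (r_time) and on a lapse plus relearning (f_time) from the revlog, then, for each candidate target retention, iterates the expected future review time of every (difficulty, stability) state to a fixed point, and finally weights the best retention per difficulty by the difficulty distribution. Restating the loop above in equations (t_r and t_f are r_time and f_time, S_recall and S_forget the FSRS next-stability formulas, P(d) the difficulty distribution, and S_0(d) the initial stability for difficulty d):

    I(S) = \max\big(1,\ \mathrm{round}(S \ln R_{\mathrm{target}} / \ln 0.9)\big), \qquad p = 0.9^{I(S)/S}
    T(d, S) \leftarrow \min\Big(T(d, S),\ p\,[T(d, S_{\mathrm{recall}}) + t_r] + (1 - p)\,[T(\min(d + 1, 10), S_{\mathrm{forget}}) + t_f]\Big)
    \mathrm{suggested\ retention} = \sum_{d=1}^{10} P(d)\ \operatorname*{arg\,min}_{R_{\mathrm{target}}} T\big(d, S_0(d)\big)
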
requirements.txt CHANGED
@@ -1,7 +1 @@
- matplotlib==3.4.3
- numpy==1.23.3
- pandas==1.3.2
- scikit_learn==1.1.2
- torch==1.9.0
- tqdm==4.64.1
- plotly==5.13.0
+ fsrs4anki_optimizer==3.24.1

utilities.py CHANGED
@@ -1,27 +1,9 @@
- from functools import partial
- import datetime
  from zipfile import ZipFile
- 
- import sqlite3
- import time
- 
- import gradio as gr
- from tqdm.auto import tqdm
- import pandas as pd
- import numpy as np
  import os
- from datetime import timedelta, datetime
  from pathlib import Path
  
- import torch
- from sklearn.utils import shuffle
- 
- from model import Collection, init_w, FSRS, WeightClipper, lineToTensor
- 
  
  # Extract the collection file or deck file to get the .anki21 database.
- 
- 
  def extract(file, prefix):
      proj_dir = Path(f'projects/{prefix}_{file.orig_name.replace(".", "_").replace("@", "_")}')
      with ZipFile(file, 'r') as zip_ref:
@@ -29,272 +11,6 @@ def extract(file, prefix):
      # print(f"Extracted {file.orig_name} successfully!")
      return proj_dir
  
- 
- def create_time_series_features(revlog_start_date, timezone, next_day_starts_at, proj_dir,
-                                 progress=gr.Progress(track_tqdm=True)):
-     if os.path.isfile(proj_dir / "collection.anki21b"):
-         os.remove(proj_dir / "collection.anki21b")
-         raise gr.Error(
-             "Please export the file with `support older Anki versions` if you use the latest version of Anki.")
-     elif os.path.isfile(proj_dir / "collection.anki21"):
-         con = sqlite3.connect(proj_dir / "collection.anki21")
-     elif os.path.isfile(proj_dir / "collection.anki2"):
-         con = sqlite3.connect(proj_dir / "collection.anki2")
-     else:
-         raise Exception("Collection not exist!")
-     cur = con.cursor()
-     res = cur.execute("SELECT * FROM revlog")
-     revlog = res.fetchall()
- 
-     df = pd.DataFrame(revlog)
-     df.columns = ['id', 'cid', 'usn', 'r', 'ivl',
-                   'last_lvl', 'factor', 'time', 'type']
-     df = df[(df['cid'] <= time.time() * 1000) &
-             (df['id'] <= time.time() * 1000) &
-             (df['r'] > 0)].copy()
-     df['create_date'] = pd.to_datetime(df['cid'] // 1000, unit='s')
-     df['create_date'] = df['create_date'].dt.tz_localize(
-         'UTC').dt.tz_convert(timezone)
-     df['review_date'] = pd.to_datetime(df['id'] // 1000, unit='s')
-     df['review_date'] = df['review_date'].dt.tz_localize(
-         'UTC').dt.tz_convert(timezone)
-     df.drop(df[df['review_date'].dt.year < 2006].index, inplace=True)
-     df.sort_values(by=['cid', 'id'], inplace=True, ignore_index=True)
-     type_sequence = np.array(df['type'])
-     time_sequence = np.array(df['time'])
-     df.to_csv(proj_dir / "revlog.csv", index=False)
-     # print("revlog.csv saved.")
-     df = df[df['type'] != 3].copy()
-     df['real_days'] = df['review_date'] - timedelta(hours=next_day_starts_at)
-     df['real_days'] = pd.DatetimeIndex(df['real_days'].dt.floor('D', ambiguous='infer', nonexistent='shift_forward')).to_julian_date()
-     df.drop_duplicates(['cid', 'real_days'], keep='first', inplace=True)
-     df['delta_t'] = df.real_days.diff()
-     df.dropna(inplace=True)
-     df['delta_t'] = df['delta_t'].astype(dtype=int)
-     df['i'] = 1
-     df['r_history'] = ""
-     df['t_history'] = ""
-     col_idx = {key: i for i, key in enumerate(df.columns)}
- 
-     # code from https://github.com/L-M-Sherlock/anki_revlog_analysis/blob/main/revlog_analysis.py
-     def get_feature(x):
-         last_kind = None
-         for idx, log in enumerate(x.itertuples()):
-             if last_kind is not None and last_kind in (1, 2) and log.type == 0:
-                 return x.iloc[:idx]
-             last_kind = log.type
-             if idx == 0:
-                 if log.type != 0:
-                     return x.iloc[:idx]
-                 x.iloc[idx, col_idx['delta_t']] = 0
-             if idx == x.shape[0] - 1:
-                 break
-             x.iloc[idx + 1, col_idx['i']] = x.iloc[idx, col_idx['i']] + 1
-             x.iloc[idx + 1, col_idx[
-                 't_history']] = f"{x.iloc[idx, col_idx['t_history']]},{x.iloc[idx, col_idx['delta_t']]}"
-             x.iloc[idx + 1, col_idx['r_history']] = f"{x.iloc[idx, col_idx['r_history']]},{x.iloc[idx, col_idx['r']]}"
-         return x
- 
-     tqdm.pandas(desc='Saving Trainset')
-     df = df.groupby('cid', as_index=False, group_keys=False).progress_apply(get_feature)
-     df = df[df['id'] >= time.mktime(datetime.strptime(revlog_start_date, "%Y-%m-%d").timetuple()) * 1000]
-     df["t_history"] = df["t_history"].map(lambda x: x[1:] if len(x) > 1 else x)
-     df["r_history"] = df["r_history"].map(lambda x: x[1:] if len(x) > 1 else x)
-     df.to_csv(proj_dir / 'revlog_history.tsv', sep="\t", index=False)
-     # print("Trainset saved.")
- 
-     def cal_retention(group: pd.DataFrame) -> pd.DataFrame:
-         group['retention'] = round(group['r'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x]).mean(), 4)
-         group['total_cnt'] = group.shape[0]
-         return group
- 
-     tqdm.pandas(desc='Calculating Retention')
-     df = df.groupby(by=['r_history', 'delta_t']).progress_apply(cal_retention)
-     # print("Retention calculated.")
-     df = df.drop(columns=['id', 'cid', 'usn', 'ivl', 'last_lvl', 'factor', 'time', 'type', 'create_date', 'review_date',
-                           'real_days', 'r', 't_history'])
-     df.drop_duplicates(inplace=True)
-     df['retention'] = df['retention'].map(lambda x: max(min(0.99, x), 0.01))
- 
-     def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
-         group_cnt = sum(group['total_cnt'])
-         if group_cnt < 10:
-             return pd.DataFrame()
-         group['group_cnt'] = group_cnt
-         if group['i'].values[0] > 1:
-             r_ivl_cnt = sum(group['delta_t'] * group['retention'].map(np.log) * pow(group['total_cnt'], 2))
-             ivl_ivl_cnt = sum(group['delta_t'].map(lambda x: x ** 2) * pow(group['total_cnt'], 2))
-             group['stability'] = round(np.log(0.9) / (r_ivl_cnt / ivl_ivl_cnt), 1)
-         else:
-             group['stability'] = 0.0
-         group['avg_retention'] = round(
-             sum(group['retention'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 3)
-         group['avg_interval'] = round(
-             sum(group['delta_t'] * pow(group['total_cnt'], 2)) / sum(pow(group['total_cnt'], 2)), 1)
-         del group['total_cnt']
-         del group['retention']
-         del group['delta_t']
-         return group
- 
-     tqdm.pandas(desc='Calculating Stability')
-     df = df.groupby(by=['r_history'], group_keys=False).progress_apply(cal_stability)
-     # print("Stability calculated.")
-     df.reset_index(drop=True, inplace=True)
-     df.drop_duplicates(inplace=True)
-     df.sort_values(by=['r_history'], inplace=True, ignore_index=True)
- 
-     df_out = pd.DataFrame()
-     if df.shape[0] > 0:
-         for idx in tqdm(df.index):
-             item = df.loc[idx]
-             index = df[(df['i'] == item['i'] + 1) & (df['r_history'].str.startswith(item['r_history']))].index
-             df.loc[index, 'last_stability'] = item['stability']
-         df['factor'] = round(df['stability'] / df['last_stability'], 2)
-         df = df[(df['i'] >= 2) & (df['group_cnt'] >= 100)]
-         df['last_recall'] = df['r_history'].map(lambda x: x[-1])
-         df = df[df.groupby(['i', 'r_history'], group_keys=False)['group_cnt'].transform(max) == df['group_cnt']]
-         df.to_csv(proj_dir / 'stability_for_analysis.tsv', sep='\t', index=None)
-         # print("1:again, 2:hard, 3:good, 4:easy\n")
-         # print(df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][
-         #           ['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(
-         #     index=False))
-         # print("Analysis saved!")
- 
-         df_out = df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][
-             ['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']]
-     return type_sequence, time_sequence, df_out
- 
- 
- def train_model(proj_dir, progress=gr.Progress(track_tqdm=True)):
-     model = FSRS(init_w)
- 
-     clipper = WeightClipper()
-     optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
- 
-     dataset = pd.read_csv(proj_dir / "revlog_history.tsv", sep='\t', index_col=None,
-                           dtype={'r_history': str, 't_history': str})
-     dataset = dataset[(dataset['i'] > 1) & (dataset['delta_t'] > 0) & (dataset['t_history'].str.count(',0') == 0)]
- 
-     tqdm.pandas(desc='Tensorizing Line')
-     dataset['tensor'] = dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]),
-                                                axis=1)
-     # print("Tensorized!")
- 
-     pre_train_set = dataset[dataset['i'] == 2]
-     # pretrain
-     epoch_len = len(pre_train_set)
-     n_epoch = 1
-     pbar = tqdm(desc="Pre-training", colour="red", total=epoch_len * n_epoch)
- 
-     for k in range(n_epoch):
-         for i, (_, row) in enumerate(shuffle(pre_train_set, random_state=2022 + k).iterrows()):
-             model.train()
-             optimizer.zero_grad()
-             output_t = [(model.zero, model.zero)]
-             for input_t in row['tensor']:
-                 output_t.append(model(input_t, *output_t[-1]))
-             loss = model.loss(output_t[-1][0], row['delta_t'],
-                               {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])
-             if np.isnan(loss.data.item()):
-                 # Exception Case
-                 # print(row, output_t)
-                 raise Exception('error case')
-             loss.backward()
-             optimizer.step()
-             model.apply(clipper)
-             pbar.update()
-     pbar.close()
-     # for name, param in model.named_parameters():
-     #     print(f"{name}: {list(map(lambda x: round(float(x), 4), param))}")
- 
-     train_set = dataset[dataset['i'] > 2]
-     epoch_len = len(train_set)
-     n_epoch = 1
-     print_len = max(epoch_len * n_epoch // 10, 1)
-     pbar = tqdm(desc="Training", total=epoch_len * n_epoch)
- 
-     for k in range(n_epoch):
-         for i, (_, row) in enumerate(shuffle(train_set, random_state=2022 + k).iterrows()):
-             model.train()
-             optimizer.zero_grad()
-             output_t = [(model.zero, model.zero)]
-             for input_t in row['tensor']:
-                 output_t.append(model(input_t, *output_t[-1]))
-             loss = model.loss(output_t[-1][0], row['delta_t'],
-                               {1: 0, 2: 1, 3: 1, 4: 1}[row['r']])
-             if np.isnan(loss.data.item()):
-                 # Exception Case
-                 # print(row, output_t)
-                 raise Exception('error case')
-             loss.backward()
-             for param in model.parameters():
-                 param.grad[:2] = torch.zeros(2)
-             optimizer.step()
-             model.apply(clipper)
-             pbar.update()
- 
-             # if (k * epoch_len + i) % print_len == 0:
-             #     print(f"iteration: {k * epoch_len + i + 1}")
-             #     for name, param in model.named_parameters():
-             #         print(f"{name}: {list(map(lambda x: round(float(x), 4), param))}")
-     pbar.close()
- 
-     w = list(map(lambda x: round(float(x), 4), dict(model.named_parameters())['w'].data))
- 
-     # print("\nTraining finished!")
-     return w, dataset
- 
- 
- def process_personalized_collection(requestRetention, w):
-     my_collection = Collection(w)
-     rating_dict = {1: "again", 2: "hard", 3: "good", 4: "easy"}
-     rating_markdown = []
-     for first_rating in (1, 2, 3, 4):
-         rating_markdown.append(f'## First Rating: {first_rating} ({rating_dict[first_rating]})')
-         t_history = "0"
-         d_history = "0"
-         r_history = f"{first_rating}"  # the first rating of the new card
-         # print("stability, difficulty, lapses")
-         for i in range(10):
-             states = my_collection.states(t_history, r_history)
-             # print('{0:9.2f} {1:11.2f} {2:7.0f}'.format(
-             #     *list(map(lambda x: round(float(x), 4), states))))
-             next_t = max(round(float(np.log(requestRetention) / np.log(0.9) * states[0])), 1)
-             difficulty = round(float(states[1]), 1)
-             t_history += f',{int(next_t)}'
-             d_history += f',{difficulty}'
-             r_history += f",3"
-         rating_markdown.append(f"**rating history**: {r_history}")
-         rating_markdown.append(f"**interval history**: {t_history}")
-         rating_markdown.append(f"**difficulty history**: {d_history}\n")
-     rating_markdown = '\n\n'.join(rating_markdown)
-     return my_collection, rating_markdown
- 
- 
- def log_loss(my_collection, row):
-     states = my_collection.states(row['t_history'], row['r_history'])
-     row['log_loss'] = float(my_collection.model.loss(states[0], row['delta_t'], {1: 0, 2: 1, 3: 1, 4: 1}[row['r']]))
-     return row
- 
- 
- def my_loss(dataset, w):
-     my_collection = Collection(init_w)
-     tqdm.pandas(desc='Calculating Loss before Training')
-     dataset = dataset.progress_apply(partial(log_loss, my_collection), axis=1)
-     # print(f"Loss before training: {dataset['log_loss'].mean():.4f}")
-     loss_before = f"{dataset['log_loss'].mean():.4f}"
-     my_collection = Collection(w)
-     tqdm.pandas(desc='Calculating Loss After Training')
-     dataset = dataset.progress_apply(partial(log_loss, my_collection), axis=1)
-     # print(f"Loss after training: {dataset['log_loss'].mean():.4f}")
-     loss_after = f"{dataset['log_loss'].mean():.4f}"
-     return f"""
-     **Loss before training**: {loss_before}
- 
-     **Loss after training**: {loss_after}
-     """
- 
- 
  def cleanup(proj_dir: Path, files):
      """
      Delete all files in prefix that dont have filenames in files

  from zipfile import ZipFile
  import os
  from pathlib import Path
  
  
  # Extract the collection file or deck file to get the .anki21 database.
  def extract(file, prefix):
      proj_dir = Path(f'projects/{prefix}_{file.orig_name.replace(".", "_").replace("@", "_")}')
      with ZipFile(file, 'r') as zip_ref:
  
      # print(f"Extracted {file.orig_name} successfully!")
      return proj_dir
  
  def cleanup(proj_dir: Path, files):
      """
      Delete all files in prefix that dont have filenames in files