# fsrs4anki_app / model.py
# Last commit by JarrettYe: "Update to 3.13.0" (233f10f); file size 3.88 kB.
import numpy as np
import torch
from torch import nn
# Initial FSRS weights. Meaning of each index:
#   w[0]:  initial_stability_for_again_answer
#   w[1]:  initial_stability_step_per_rating
#   w[2]:  initial_difficulty_for_good_answer
#   w[3]:  initial_difficulty_step_per_rating
#   w[4]:  next_difficulty_step_per_rating
#   w[5]:  next_difficulty_reversion_to_mean_speed (used to avoid ease hell)
#   w[6]:  next_stability_factor_after_success
#   w[7]:  next_stability_stabilization_decay_after_success
#   w[8]:  next_stability_retrievability_gain_after_success
#   w[9]:  next_stability_factor_after_failure
#   w[10]: next_stability_difficulty_decay_after_failure
#          (it is only used in the lapse/forget branch of FSRS.forward)
#   w[11]: next_stability_stability_gain_after_failure
#   w[12]: next_stability_retrievability_gain_after_failure
# For more details about the parameters, please see:
# https://github.com/open-spaced-repetition/fsrs4anki/wiki/Free-Spaced-Repetition-Scheduler
init_w = [1, 1, 5, -0.5, -0.5, 0.2, 1.4, -0.12, 0.8, 2, -0.2, 0.2, 1]
class FSRS(nn.Module):
    """FSRS memory model: updates (stability, difficulty) for one review.

    The weights are held in the learnable parameter ``self.w``; ``forward``
    reads ``w[0]`` .. ``w[12]``. See the parameter table next to ``init_w``
    for the meaning of each index.
    """

    def __init__(self, w):
        """
        :param w: sequence of initial weight values; indices 0-12 are used.
        """
        super().__init__()
        self.w = nn.Parameter(torch.FloatTensor(w))
        # "Never reviewed" sentinel for both stability and difficulty.
        # It is registered as a buffer so that ``model.to(device)`` moves it
        # together with ``self.w``. A plain tensor attribute stays on the CPU,
        # and then ``torch.equal(s, self.zero)`` in ``forward`` fails for
        # tensors on another device. ``persistent=False`` keeps it out of
        # ``state_dict``, so existing checkpoints (which only hold ``w``)
        # still load.
        self.register_buffer('zero', torch.FloatTensor([0.0]), persistent=False)

    def forward(self, x, s, d):
        """Apply one review to the memory state.

        :param x: [review interval, review response]. The response x[1] is
            the rating: 1 is "again" and 3 is "good" (see ``init_w`` docs).
        :param s: current stability; exactly ``self.zero`` before the first
            review.
        :param d: current difficulty (not read on the first review).
        :return: tuple ``(new_s, new_d)``, with ``new_d`` clamped to [1, 10].
        """
        if torch.equal(s, self.zero):
            # First learn: initialise the memory state from the rating alone.
            new_s = self.w[0] + self.w[1] * (x[1] - 1)
            new_d = self.w[2] + self.w[3] * (x[1] - 3)
            new_d = new_d.clamp(1, 10)
        else:
            # Retrievability after x[0] days: 0.9 ** (x[0] / s), so r == 0.9
            # when the interval equals the stability.
            r = torch.exp(np.log(0.9) * x[0] / s)
            new_d = d + self.w[4] * (x[1] - 3)
            # Pull difficulty back toward its initial value w[2].
            new_d = self.mean_reversion(self.w[2], new_d)
            new_d = new_d.clamp(1, 10)
            if x[1] > 1:
                # Recall: stability grows; the growth is larger for easier
                # cards (11 - new_d) and for lower retrievability (1 - r).
                new_s = s * (1 + torch.exp(self.w[6]) *
                             (11 - new_d) *
                             torch.pow(s, self.w[7]) *
                             (torch.exp((1 - r) * self.w[8]) - 1))
            else:
                # Forget (rating 1): stability is recomputed from difficulty,
                # the previous stability and retrievability.
                new_s = (self.w[9] * torch.pow(new_d, self.w[10]) *
                         torch.pow(s, self.w[11]) *
                         torch.exp((1 - r) * self.w[12]))
        return new_s, new_d

    def loss(self, s, t, r):
        """Negative log-likelihood of a recall outcome.

        The predicted recall probability is ``p = 0.9 ** (t / s)``.

        :param s: stability.
        :param t: elapsed interval.
        :param r: recall label weighting ``log(p)``; presumably 1 means
            recalled and 0 means forgotten.
        :return: ``-(r * log(p) + (1 - r) * log(1 - p))``.
        """
        # NOTE(review): for t == 0, p == 1 and log(1 - p) is -inf, so the
        # result is nan even when r == 1 (0 * -inf). Confirm callers only
        # pass t > 0.
        return - (r * np.log(0.9) * t / s + (1 - r) * torch.log(1 - torch.exp(np.log(0.9) * t / s)))

    def mean_reversion(self, init, current):
        """Blend ``current`` toward ``init``: ``w[5] * init + (1 - w[5]) * current``."""
        return self.w[5] * init + (1 - self.w[5]) * current
class WeightClipper:
    """Clamp each FSRS weight in ``module.w`` into its allowed interval.

    The instance is called with a module, e.g. via ``nn.Module.apply``
    (presumably how the training loop uses it; confirm). Objects without a
    ``w`` attribute are left untouched.
    """

    def __init__(self, frequency=1):
        # Kept as an attribute for callers; ``__call__`` never reads it.
        self.frequency = frequency

    def __call__(self, module):
        if not hasattr(module, 'w'):
            return
        # (lower, upper) for w[0], w[1], ..., w[12], in index order.
        limits = (
            (0.1, 10),
            (0.1, 5),
            (1, 10),
            (-5, -0.1),
            (-5, -0.1),
            (0, 0.5),
            (0, 2),
            (-0.2, -0.01),
            (0.01, 1.5),
            (0.5, 5),
            (-2, -0.01),
            (0.01, 0.9),
            (0.01, 2),
        )
        params = module.w.data
        for index, (low, high) in enumerate(limits):
            params[index] = params[index].clamp(low, high)
        module.w.data = params
def lineToTensor(line):
    """Turn one review history into a float tensor of shape ``(n, 2)``.

    :param line: pair ``(intervals, responses)``, each a comma-separated
        string of integers, e.g. ``("0,1,3", "3,3,1")``.
    :return: tensor whose row ``i`` is ``[interval_i, response_i]``. The
        row count ``n`` is the number of responses; surplus intervals are
        ignored, and too few intervals raise IndexError.
    """
    interval_tokens, response_tokens = line[0].split(','), line[1].split(',')
    result = torch.zeros(len(response_tokens), 2)
    for row, token in enumerate(response_tokens):
        result[row, 0] = int(interval_tokens[row])
        result[row, 1] = int(token)
    return result
class Collection:
    """Replays a card's review history through an FSRS model, without gradients."""

    def __init__(self, w):
        """
        :param w: weights passed straight to ``FSRS``.
        """
        self.model = FSRS(w)

    def states(self, t_history, r_history):
        """Return the memory state after replaying a whole review history.

        :param t_history: comma-separated review intervals, e.g. ``"0,1,3"``.
        :param r_history: comma-separated review responses, one per interval.
        :return: ``(stability, difficulty)`` produced by the last review.
        """
        with torch.no_grad():
            line_tensor = lineToTensor((t_history, r_history))
            # Start from the "never reviewed" sentinel that FSRS.forward
            # detects. Only the running state is kept, because callers get
            # just the final one.
            state = (self.model.zero, self.model.zero)
            for review in line_tensor:
                state = self.model(review, *state)
        return state