|
from autograd.extend import primitive, defvjp, vspace |
|
from autograd.builtins import tuple |
|
from autograd import make_vjp |
|
|
|
@primitive
def fixed_point(f, a, x0, distance, tol):
    """Find a fixed point of the map ``f(a)`` by plain forward iteration.

    ``f`` is curried: ``f(a)`` must return the one-argument step function.
    Iteration starts from ``x0`` and repeatedly applies the step until
    ``distance(current, previous) <= tol``.

    Args:
        f: Callable ``a -> (x -> x_next)``.
        a: Parameters of the map (the argument the registered VJP targets).
        x0: Initial iterate.
        distance: Callable ``(x_new, x_old) -> scalar`` used as the stopping test.
        tol: Convergence threshold compared against ``distance``.

    Returns:
        The first iterate whose distance to its predecessor is ``<= tol``.

    Note:
        There is no iteration cap. If the iteration never meets the tolerance,
        this loops forever.
    """
    step = f(a)
    # At least one step is always taken before the convergence test.
    prev, cur = x0, step(x0)
    while distance(cur, prev) > tol:
        prev, cur = cur, step(cur)
    return cur
|
|
|
def fixed_point_vjp(ans, f, a, x0, distance, tol):
    """Build the reverse-mode VJP of ``fixed_point`` with respect to ``a``.

    ``ans`` is the forward solution x* (so ``x* = f(a)(x*)``). For an incoming
    cotangent ``g``, the adjoint ``w`` is obtained as the fixed point of
    ``w -> vjp_x(w) + g``, where ``vjp_x`` is the VJP of ``f(a)`` at x* with
    respect to ``x``. The solve starts from zeros shaped like ``x0`` and reuses
    the forward ``distance``/``tol``. The result is ``vjp_a(w)``, the VJP of
    ``(a, x) -> f(a)(x)`` with respect to ``a`` at ``(a, x*)``.

    Convergence of the adjoint iteration is not checked here. It presumably
    relies on the forward map being contractive near x*, which is an
    assumption to confirm per use.
    """
    def adjoint_step(packed):
        # ``packed`` is an autograd tuple (a, x*, cotangent). Routing it through
        # ``fixed_point`` as the "a" argument presumably lets this adjoint solve
        # itself be differentiated again.
        a_val, x_fixed, x_fixed_bar = packed
        vjp_wrt_x, _ = make_vjp(f(a_val))(x_fixed)
        space = vspace(x_fixed)

        def step(w):
            return space.add(vjp_wrt_x(w), x_fixed_bar)

        return step

    # Evaluated eagerly, once: VJP of (a, x) -> f(a)(x) w.r.t. its first argument.
    vjp_wrt_a, _ = make_vjp(lambda a_val, x_val: f(a_val)(x_val))(a, ans)

    def vjp(g):
        adjoint = fixed_point(adjoint_step, tuple((a, ans, g)),
                              vspace(x0).zeros(), distance, tol)
        return vjp_wrt_a(adjoint)

    return vjp
|
|
|
# Register fixed_point_vjp as the reverse-mode rule for fixed_point.
# The positional slots line up with fixed_point's leading arguments:
#   f  -> None
#   a  -> fixed_point_vjp
#   x0 -> None
# NOTE(review): in autograd a None vjpmaker presumably yields a zero gradient
# for that argument. distance and tol get no slot at all, so differentiating
# w.r.t. them is presumably unsupported. Confirm against autograd.extend.defvjp.
defvjp(fixed_point, None, fixed_point_vjp, None)
|
|