File size: 17,260 Bytes
9dd3461 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 |
import torch
from functools import reduce
from .optimizer import Optimizer
__all__ = ['LBFGS']
def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
# ported from https://github.com/torch/optim/blob/master/polyinterp.lua
# Compute bounds of interpolation area
if bounds is not None:
xmin_bound, xmax_bound = bounds
else:
xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
# Code for most common case: cubic interpolation of 2 points
# w/ function and derivative values for both
# Solution in this case (where x2 is the farthest point):
# d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
# d2 = sqrt(d1^2 - g1*g2);
# min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
# t_new = min(max(min_pos,xmin_bound),xmax_bound);
d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
d2_square = d1**2 - g1 * g2
if d2_square >= 0:
d2 = d2_square.sqrt()
if x1 <= x2:
min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
else:
min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
return min(max(min_pos, xmin_bound), xmax_bound)
else:
return (xmin_bound + xmax_bound) / 2.
def _strong_wolfe(obj_func,
                  x,
                  t,
                  d,
                  f,
                  g,
                  gtd,
                  c1=1e-4,
                  c2=0.9,
                  tolerance_change=1e-9,
                  max_ls=25):
    """Search along ``d`` for a step length satisfying the strong Wolfe test.

    Ported from https://github.com/torch/optim/blob/master/lswolfe.lua

    The search runs in two phases that share one iteration budget: a
    bracketing phase that grows the step until an interval is found, and a
    zoom phase that shrinks that interval with cubic interpolation.

    Args:
        obj_func: callable invoked as ``obj_func(x, t, d)``; must return a
            ``(value, gradient)`` pair where the gradient supports ``.dot(d)``
            and ``.clone(...)``.
        x: passed through to ``obj_func`` untouched.
        t: initial trial step length.
        d: search direction (tensor; ``.abs().max()`` and ``.dot`` are used).
        f: objective value at step length 0.
        g: gradient at step length 0 (cloned on entry).
        gtd: directional derivative at step length 0.
        c1: coefficient of the sufficient-decrease test
            ``f_new > f + c1 * t * gtd``.
        c2: coefficient of the curvature test ``abs(gtd_new) <= -c2 * gtd``.
        tolerance_change: the zoom phase stops once the bracket width times
            the largest absolute entry of ``d`` drops below this value.
        max_ls: upper bound on ``ls_iter``, counted across both phases.

    Returns:
        Tuple ``(f_new, g_new, t, ls_func_evals)``: value, gradient and step
        length of the lowest bracket point, plus the number of ``obj_func``
        calls made.
    """
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        # sufficient decrease violated, or the value stopped decreasing:
        # the previous and current steps form the bracket
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # curvature test passed as well: current step is accepted as is.
        # NOTE: bracket_gtd is not assigned on this path; it is not needed
        # because `done` keeps the zoom loop below from running.
        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        # directional derivative is no longer negative: the previous and
        # current steps form the bracket
        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        # the next trial step is restricted to [min_step, max_step], i.e.
        # strictly beyond the current t and at most 10x larger
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev,
            f_prev,
            gtd_prev,
            t,
            f_new,
            gtd_new,
            bounds=(min_step, max_step))

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    # NOTE: bracket_gtd is not assigned here either; the zoom loop below is
    # skipped in this case because ls_iter already equals max_ls.
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    # ([-1] rather than [1] so a single-element bracket is handled too)
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
            break

        # compute new trial value
        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                               bracket[1], bracket_f[1], bracket_gtd[1])

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[high_pos] = gtd_new
            # the replaced end may now be the lower one, so re-rank both ends
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            # (executed for every outcome of the else-branch, including `done`)
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    # the lowest bracket point is reported, whether or not `done` was reached
    t = bracket[low_pos]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals
class LBFGS(Optimizer):
    """Implements L-BFGS algorithm, heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Args:
        params: parameters to optimize, forwarded to ``Optimizer.__init__``;
            they must form exactly one parameter group.
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).

    Raises:
        ValueError: if ``params`` describes more than one parameter group.
    """

    def __init__(self,
                 params,
                 lr=1,
                 max_iter=20,
                 max_eval=None,
                 tolerance_grad=1e-7,
                 tolerance_change=1e-9,
                 history_size=100,
                 line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn)
        super(LBFGS, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

    def _numel(self):
        """Return the total number of elements over all parameters (cached)."""
        if self._numel_cache is None:
            self._numel_cache = sum(p.numel() for p in self._params)
        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenate all parameter gradients into one new 1-D tensor.

        Parameters without a gradient contribute zeros and sparse gradients
        are densified first.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().reshape(-1)
            else:
                # reshape (not view): a gradient whose strides follow a
                # non-contiguous parameter cannot be viewed as 1-D and
                # view(-1) would raise. The row-major element order produced
                # here matches the view_as(p) slicing done in _add_grad.
                view = p.grad.reshape(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        """Add ``step_size * update`` to the parameters in place.

        ``update`` is a flat tensor laid out like the result of
        ``_gather_flat_grad``.
        """
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        """Return contiguous copies of all parameters."""
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        """Copy ``params_data`` (one tensor per parameter) into the parameters."""
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        """Evaluate loss and flat gradient at ``x + t * d``, then restore ``x``.

        The parameters are expected to hold ``x`` on entry; they are moved by
        ``t * d``, the closure is evaluated, and the values in ``x`` are
        copied back before returning ``(loss, flat_grad)``.
        """
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    @torch.no_grad()
    def step(self, closure):
        """Performs a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss returned by the first closure evaluation of this call,
            i.e. the loss before any parameter update.

        Raises:
            RuntimeError: if ``line_search_fn`` is neither ``None`` nor
                ``'strong_wolfe'``.
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        ro = state.get('ro')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            # (n_iter counts this call only, state['n_iter'] counts all calls)
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                # very first iteration ever: steepest descent, empty history
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                # pairs with a non-positive (or tiny) y*s are skipped
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1. / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'al' not in state:
                    state['al'] = [None] * history_size
                al = state['al']

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd)
                # the line search restores the parameters after each trial,
                # so the accepted step still has to be applied here
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        # persist the history so the next step() call continues from it
        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['ro'] = ro
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss
|