319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
class GradientGP(GaussianProcessExpr):
    """Gaussian process over the gradient ∇f of a scalar GP f.

    Provides, via autograd on the base GP:
      - ``mean(x)``       : ∇ₓ E[f](x)
      - ``knl(x, xp)``    : Hₓₓ′ k_f(x, xp), the mixed Hessian of f's kernel
                            (clamped to PSD when x == xp)
      - ``covar(G, x, xp)``: [∇ₓ covar(f, g)(x, xp)]ᵀ
    """

    def __init__(self, f, x_shape, grad_check=False, analytical_hessian=True):
        # f: underlying scalar-valued GP exposing mean/knl/covar/dtype/to.
        self.gp = f
        # Shape of the gradient output (∇f lives in the input space).
        self.x_shape = x_shape
        # When True, run torch.autograd.gradcheck / gradgradcheck in float64
        # before every evaluation (very slow; debugging aid only).
        self.grad_check = grad_check
        # True: exact double-backward Hessian via t_hessian;
        # False: finite-difference Jacobian of the kernel gradient.
        self.analytical_hessian = analytical_hessian

    @property
    def shape(self):
        return self.x_shape

    @property
    def dtype(self):
        # Delegate to the wrapped GP — this object holds no tensors itself.
        return self.gp.dtype

    def to(self, dtype):
        self.gp.to(dtype)

    def mean(self, x):
        """Return ∇ₓ E[f](x) by differentiating the base GP's mean at x."""
        f = self.gp
        if self.grad_check:
            # Verify f.mean's autograd graph in float64, then restore dtype.
            old_dtype = self.dtype
            self.to(torch.float64)
            with variable_required_grad(x):
                torch.autograd.gradcheck(f.mean, x.double())
            self.to(dtype=old_dtype)
        with variable_required_grad(x):
            return torch.autograd.grad(f.mean(x), x)[0]

    def knl(self, x, xp, eigeps=EPS):
        """Return the kernel of ∇f: Hₓₓ′ k_f(x, xp).

        When x == xp the result is a covariance matrix and must be PSD;
        eigenvalues in (-eigeps, 0) are treated as numerical noise and
        clamped to zero. An eigenvalue below -eigeps raises AssertionError.
        """
        f = self.gp
        if xp is x:
            # The two differentiation arguments must be independent leaves.
            xp = xp.detach().clone()
        grad_k_func = lambda xs, xt: torch.autograd.grad(
            f.knl(xs, xt), xs, create_graph=True)[0]
        if self.grad_check:
            # Check first and second derivatives of the kernel in float64.
            old_dtype = self.dtype
            self.to(torch.float64)
            f_knl_func = lambda xt: f.knl(xt, xp.double())
            with variable_required_grad(x):
                torch.autograd.gradcheck(f_knl_func, x.double())
                torch.autograd.gradgradcheck(lambda x: f.knl(x, x), x.double())
            with variable_required_grad(x):
                with variable_required_grad(xp):
                    torch.autograd.gradcheck(
                        lambda xp: grad_k_func(x.double(), xp)[0], xp.double())
            self.to(dtype=old_dtype)
        if self.analytical_hessian:
            Hxx_k = t_hessian(f.knl, x, xp)
        else:
            # Finite-difference fallback, computed in float64 for stability.
            with variable_required_grad(x):
                with variable_required_grad(xp):
                    old_dtype = self.dtype
                    self.to(torch.float64)
                    Hxx_k = tgradcheck.get_numerical_jacobian(
                        partial(grad_k_func, x.double()), xp.double())
                    self.to(dtype=old_dtype)
                    Hxx_k = Hxx_k.to(old_dtype)
        if torch.allclose(x, xp):
            # FIX: torch.eig was deprecated in torch 1.9 and removed in 2.0.
            # At x == xp the Hessian is symmetric, so use the symmetric
            # eigensolvers (real eigenvalues, ascending order).
            eigenvalues = torch.linalg.eigvalsh(Hxx_k)
            assert (eigenvalues > -eigeps).all(), " Hessian must be positive definite"
            small_neg_eig = (eigenvalues > -eigeps) & (eigenvalues < 0)
            if small_neg_eig.any():
                # Clamp tiny negative eigenvalues to zero and reconstruct.
                # FIX: eigenvectors are the COLUMNS of V, so A = V Λ Vᵀ —
                # the original had V and Vᵀ swapped (Vᵀ Λ V).
                eigenvalues, eigenvectors = torch.linalg.eigh(Hxx_k)
                evalz = eigenvalues.clone()  # avoid in-place edit of solver output
                evalz[small_neg_eig] = 0
                Hxx_k = eigenvectors @ torch.diag(evalz) @ eigenvectors.T
        return Hxx_k

    def covar(self, G, x, xp):
        """Return covar(∇f, g)(x, xp) = [∇ₓ covar(f, g)(x, xp)]ᵀ."""
        with variable_required_grad(x):
            J_covar_fg = t_jac(self.gp.covar(G, x, xp), x)
        return J_covar_fg.t()

    def __str__(self):
        return "∇ {self.gp!s}".format(self=self)
|