Source code for deepxde.optimizers.pytorch.nncg

from functools import reduce

import torch
from torch.func import vmap
from torch.optim import Optimizer


def _armijo(f, x, gx, dx, t, alpha=0.1, beta=0.5):
    """Line search to find a step size that satisfies the Armijo condition."""
    f0 = f(x, 0, dx)
    f1 = f(x, t, dx)
    while f1 > f0 + alpha * t * gx.dot(dx):
        t *= beta
        f1 = f(x, t, dx)
    return t


def _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, x):
    """Applies the inverse of the Nystrom approximation of the Hessian to a vector."""
    z = U.T @ x
    z = (lambd_r + mu) * (U @ (S_mu_inv * z)) + (x - U @ z)
    return z


def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters):
    """Solves a positive-definite linear system using NyströmPCG.

    `Frangella et al. Randomized Nyström Preconditioning.
    SIAM Journal on Matrix Analysis and Applications, 2023.
    <https://epubs.siam.org/doi/10.1137/21M1466244>`
    """
    lambd_r = S[r - 1]
    S_mu_inv = (S + mu) ** (-1)

    resid = b - (hess(x) + mu * x)
    with torch.no_grad():
        z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid)
        p = z.clone()

    i = 0

    while torch.norm(resid) > tol and i < max_iters:
        v = hess(p) + mu * p
        with torch.no_grad():
            alpha = torch.dot(resid, z) / torch.dot(p, v)
            x += alpha * p

            rTz = torch.dot(resid, z)
            resid -= alpha * v
            z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid)
            beta = torch.dot(resid, z) / rTz

            p = z + beta * p

        i += 1

    if torch.norm(resid) > tol:
        print(
            "Warning: PCG did not converge to tolerance. "
            f"Tolerance was {tol} but norm of residual is {torch.norm(resid)}"
        )

    return x


[docs] class NNCG(Optimizer): """Implementation of NysNewtonCG, a damped Newton-CG method that uses Nyström preconditioning. `Rathore et al. Challenges in Training PINNs: A Loss Landscape Perspective. Preprint, 2024. <https://arxiv.org/abs/2402.01868>` .. warning:: This optimizer doesn't support per-parameter options and parameter groups (there can be only one). NOTE: This optimizer is currently a beta version. Our implementation is inspired by the PyTorch implementation of `L-BFGS <https://pytorch.org/docs/stable/_modules/torch/optim/lbfgs.html#LBFGS>`. The parameters rank and mu will probably need to be tuned for your specific problem. If the optimizer is running very slowly, you can try one of the following: - Increase the rank (this should increase the accuracy of the Nyström approximation in PCG) - Reduce cg_tol (this will allow PCG to terminate with a less accurate solution) - Reduce cg_max_iters (this will allow PCG to terminate after fewer iterations) Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1.0) rank (int, optional): rank of the Nyström approximation (default: 10) mu (float, optional): damping parameter (default: 1e-4) update_freq (int, optional): frequency of updating the preconditioner chunk_size (int, optional): number of Hessian-vector products to be computed in parallel (default: 1) cg_tol (float, optional): tolerance for PCG (default: 1e-16) cg_max_iters (int, optional): maximum number of PCG iterations (default: 1000) line_search_fn (str, optional): either 'armijo' or None (default: None) verbose (bool, optional): verbosity (default: False) """ def __init__( self, params, lr=1.0, rank=10, mu=1e-4, update_freq=20, chunk_size=1, cg_tol=1e-16, cg_max_iters=1000, line_search_fn=None, verbose=False, ): defaults = { "lr": lr, "rank": rank, "mu": mu, "update_freq": update_freq, "chunk_size": chunk_size, "cg_tol": cg_tol, "cg_max_iters": cg_max_iters, "line_search_fn": line_search_fn, } self.rank = rank self.mu = mu self.update_freq = update_freq self.chunk_size = chunk_size self.cg_tol = cg_tol self.cg_max_iters = cg_max_iters self.line_search_fn = line_search_fn self.verbose = verbose self.U = None self.S = None self.n_iters = 0 super().__init__(params, defaults) if len(self.param_groups) > 1: raise ValueError( "NNCG doesn't currently support " "per-parameter options (parameter groups)" ) if self.line_search_fn is not None and self.line_search_fn != "armijo": raise ValueError("NNCG only supports Armijo line search") self._params = self.param_groups[0]["params"] self._params_list = list(self._params) self._numel_cache = None
[docs] def step(self, closure): """Perform a single optimization step. Args: closure (callable): A closure that reevaluates the model and returns the loss w.r.t. the parameters. """ if self.n_iters == 0: # Store the previous direction for warm starting PCG self.old_dir = torch.zeros(self._numel(), device=self._params[0].device) loss = closure() # Compute gradient via torch.autograd.grad g_tuple = torch.autograd.grad(loss, self._params_list, create_graph=True) g = torch.cat([gi.view(-1) for gi in g_tuple if gi is not None]) if self.n_iters % self.update_freq == 0: self._update_preconditioner(g) # One step update for group_idx, group in enumerate(self.param_groups): def hvp_temp(x): return self._hvp(g, self._params_list, x) # Calculate the Newton direction d = _nystrom_pcg( hvp_temp, g, self.old_dir, self.mu, self.U, self.S, self.rank, self.cg_tol, self.cg_max_iters, ) # Store the previous direction for warm starting PCG self.old_dir = d # Check if d is a descent direction if torch.dot(d, g) <= 0: print("Warning: d is not a descent direction") if self.line_search_fn == "armijo": x_init = self._clone_param() def obj_func(x, t, dx): self._add_grad(t, dx) loss = float(closure()) self._set_param(x) return loss # Use -d for convention t = _armijo(obj_func, x_init, g, -d, group["lr"]) else: t = group["lr"] self.state[group_idx]["t"] = t # update parameters ls = 0 for p in group["params"]: np = torch.numel(p) dp = d[ls : ls + np].view(p.shape) ls += np p.data.add_(-dp, alpha=t) self.n_iters += 1 return loss
def _update_preconditioner(self, grad): """Update the Nyström approximation of the Hessian. Args: grad (torch.Tensor): gradient of the loss w.r.t. the parameters. """ # Generate test matrix (NOTE: This is transposed test matrix) p = grad.shape[0] Phi = torch.randn((self.rank, p), device=grad.device) / (p**0.5) Phi = torch.linalg.qr(Phi.t(), mode="reduced")[0].t() Y = self._hvp_vmap(grad, self._params_list)(Phi) # Calculate shift shift = torch.finfo(Y.dtype).eps Y_shifted = Y + shift * Phi # Calculate Phi^T * H * Phi (w/ shift) for Cholesky choleskytarget = torch.mm(Y_shifted, Phi.t()) # Perform Cholesky, if fails, do eigendecomposition # The new shift is the abs of smallest eigenvalue (negative) # plus the original shift try: C = torch.linalg.cholesky(choleskytarget) except torch.linalg.LinAlgError: # eigendecomposition, eigenvalues and eigenvector matrix eigs, eigvectors = torch.linalg.eigh(choleskytarget) shift = shift + torch.abs(torch.min(eigs)) # add shift to eigenvalues eigs = eigs + shift # put back the matrix for Cholesky by eigenvector * eigenvalues # after shift * eigenvector^T C = torch.linalg.cholesky( torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T)) ) try: B = torch.linalg.solve_triangular(C, Y_shifted, upper=False, left=True) # temporary fix for issue @ https://github.com/pytorch/pytorch/issues/97211 except RuntimeError: B = torch.linalg.solve_triangular( C.to("cpu"), Y_shifted.to("cpu"), upper=False, left=True ).to(C.device) # B = V * S * U^T b/c we have been using transposed sketch _, S, UT = torch.linalg.svd(B, full_matrices=False) self.U = UT.t() self.S = torch.max(torch.square(S) - shift, torch.tensor(0.0)) self.rho = self.S[-1] if self.verbose: print(f"Approximate eigenvalues = {self.S}") def _hvp_vmap(self, grad_params, params): return vmap( lambda v: self._hvp(grad_params, params, v), in_dims=0, chunk_size=self.chunk_size, ) def _hvp(self, grad_params, params, v): Hv = torch.autograd.grad(grad_params, params, grad_outputs=v, retain_graph=True) Hv = tuple(Hvi.detach() for Hvi in Hv) return torch.cat([Hvi.reshape(-1) for Hvi in Hv]) def _numel(self): if self._numel_cache is None: self._numel_cache = reduce( lambda total, p: total + p.numel(), self._params, 0 ) return self._numel_cache def _add_grad(self, step_size, update): offset = 0 for p in self._params: numel = p.numel() # Avoid in-place operation by creating a new tensor p.data = p.data.add( update[offset : offset + numel].view_as(p), alpha=step_size ) offset += numel assert offset == self._numel() def _clone_param(self): return [p.clone(memory_format=torch.contiguous_format) for p in self._params] def _set_param(self, params_data): for p, pdata in zip(self._params, params_data): # Replace the .data attribute of the tensor p.data = pdata.data