deepxde.optimizers.jax package

Submodules

deepxde.optimizers.jax.optimizers module

deepxde.optimizers.jax.optimizers.apply_updates(params: Array | ndarray | bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree], updates: Array | ndarray | bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → Array | ndarray | bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]

Applies an update to the corresponding parameters.

This is a utility function that applies an update to a set of parameters and returns the updated parameters to the caller. For example, the update may be a gradient transformed by a sequence of `GradientTransformations`. This function is exposed for convenience, but it only adds updates to parameters; you can also apply updates to parameters manually with jax.tree.map (e.g., if you want to manipulate updates in custom ways before applying them).

Parameters:
  • params – a tree of parameters.

  • updates – a tree of updates; its tree structure and the shapes of its leaf nodes must match those of params.

Returns:

Updated parameters, with the same structure, shape, and type as params.
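The docstring notes that updates can also be applied by hand with a tree map, for instance to modify them first. Below is a minimal sketch of that manual route, assuming only JAX is installed. The helper name manual_apply_updates and the per-leaf clipping are hypothetical illustrations, not part of DeepXDE. It uses jax.tree_util.tree_map, which behaves like jax.tree.map and also exists in older JAX releases, and it casts each result back to the parameter's dtype to keep the "same type as params" property described above.

```python
import jax
import jax.numpy as jnp


def manual_apply_updates(params, updates, max_abs=None):
    """Hypothetical helper: add updates to params leaf by leaf.

    If max_abs is given, each update leaf is clipped to [-max_abs, max_abs]
    before being added (an example of a custom manipulation).
    """

    def add_leaf(p, u):
        if max_abs is not None:
            u = jnp.clip(u, -max_abs, max_abs)
        # Cast back so each leaf keeps the dtype of the original parameter.
        return jnp.asarray(p + u).astype(jnp.asarray(p).dtype)

    # tree_map walks both trees in lockstep; their structures must match.
    return jax.tree_util.tree_map(add_leaf, params, updates)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
updates = {"w": jnp.array([0.1, -3.0]), "b": jnp.array(0.25)}

# The -3.0 update is clipped to -1.0 before being added.
new_params = manual_apply_updates(params, updates, max_abs=1.0)
```

With max_abs=None the helper reduces to a plain leaf-wise sum, which is what apply_updates is documented to do.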

deepxde.optimizers.jax.optimizers.get(optimizer, learning_rate=None, decay=None)

Retrieves an optax Optimizer instance.

deepxde.optimizers.jax.optimizers.is_external_optimizer(optimizer)

Module contents

deepxde.optimizers.jax.apply_updates(params: Array | ndarray | bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree], updates: Array | ndarray | bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]) → Array | ndarray | bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]

Same as deepxde.optimizers.jax.optimizers.apply_updates, documented under Submodules above.

deepxde.optimizers.jax.get(optimizer, learning_rate=None, decay=None)

Retrieves an optax Optimizer instance.

deepxde.optimizers.jax.is_external_optimizer(optimizer)