Source code for optax._src.update

# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Apply transformed gradient updates to parameters."""

import jax
import jax.numpy as jnp
from optax._src import base


[docs] def apply_updates(params: base.Params, updates: base.Updates) -> base.Params: """Applies an update to the corresponding parameters. This is a utility functions that applies an update to a set of parameters, and then returns the updated parameters to the caller. As an example, the update may be a gradient transformed by a sequence of`GradientTransformations`. This function is exposed for convenience, but it just adds updates and parameters; you may also apply updates to parameters manually, using `jax.tree.map` (e.g. if you want to manipulate updates in custom ways before applying them). Args: params: a tree of parameters. updates: a tree of updates, the tree structure and the shape of the leaf nodes must match that of `params`. Returns: Updated parameters, with same structure, shape and type as `params`. """ return jax.tree.map( lambda p, u: ( None if p is None else jnp.asarray(p + u).astype(jnp.asarray(p).dtype) ), params, updates, is_leaf=lambda x: x is None, )
def incremental_update( new_tensors: base.Params, old_tensors: base.Params, step_size: jax.typing.ArrayLike) -> base.Params: """Incrementally update parameters via polyak averaging. Polyak averaging tracks an (exponential moving) average of the past parameters of a model, for use at test/evaluation time. Args: new_tensors: the latest value of the tensors. old_tensors: a moving average of the values of the tensors. step_size: the step_size used to update the polyak average on each step. Returns: an updated moving average `step_size*new+(1-step_size)*old` of the params. References: [Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046) """ return jax.tree.map( lambda new, old: ( None if new is None else step_size * new + (1.0 - step_size) * old ), new_tensors, old_tensors, is_leaf=lambda x: x is None, ) def periodic_update( new_tensors: base.Params, old_tensors: base.Params, steps: jax.typing.ArrayLike, # int update_period: jax.typing.ArrayLike, # int ) -> base.Params: """Periodically update all parameters with new values. A slow copy of a model's parameters, updated every K actual updates, can be used to implement forms of self-supervision (in supervised learning), or to stabilize temporal difference learning updates (in reinforcement learning). Args: new_tensors: the latest value of the tensors. old_tensors: a slow copy of the model's parameters. steps: number of update steps on the "online" network. update_period: every how many steps to update the "target" network. Returns: a slow copy of the model's parameters, updated every `update_period` steps. References: [Grill et al., 2020](https://arxiv.org/abs/2006.07733) [Mnih et al., 2015](https://arxiv.org/abs/1312.5602) """ return jax.lax.cond( jnp.mod(steps, update_period) == 0, lambda _: new_tensors, lambda _: old_tensors, None, )