Source code for deepxde.nn.initializers

__all__ = ["get", "VarianceScalingStacked"]

import math

from .. import config
from ..backend import backend_name, tf, torch, jax, paddle


class VarianceScalingStacked:
    """Initializer capable of adapting its scale to the shape of weights tensors.

    With `distribution="truncated_normal" or "untruncated_normal"`, samples are
    drawn from a truncated/untruncated normal distribution with a mean of zero and
    a standard deviation (after truncation, if used) `stddev = sqrt(scale / n)`,
    where n is:

        - number of input units in the weight tensor, if mode = "fan_in"
        - number of output units, if mode = "fan_out"
        - average of the numbers of input and output units, if mode = "fan_avg"

    With `distribution="uniform"`, samples are drawn from a uniform distribution
    within [-limit, limit], with `limit = sqrt(3 * scale / n)`.

    Args:
        scale: Scaling factor (positive float).
        mode: One of "fan_in", "fan_out", "fan_avg".
        distribution: Random distribution to use. One of "normal", "uniform",
            "truncated_normal", "untruncated_normal".
        seed: A Python integer. Used to create random seeds. See
            `tf.set_random_seed` for behavior.
        dtype: Default data type, used if no `dtype` argument is provided when
            calling the initializer. Only floating point types are supported.

    Raises:
        ValueError: In case of an invalid value for the "scale", "mode" or
            "distribution" arguments.
    """

    def __init__(
        self,
        scale=1.0,
        mode="fan_in",
        distribution="truncated_normal",
        seed=None,
    ):
        if scale <= 0.0:
            raise ValueError("`scale` must be a positive float.")
        if mode not in {"fan_in", "fan_out", "fan_avg"}:
            raise ValueError("Invalid `mode` argument:", mode)
        distribution = distribution.lower()
        if distribution not in {
            "normal",
            "uniform",
            "truncated_normal",
            "untruncated_normal",
        }:
            raise ValueError("Invalid `distribution` argument:", distribution)
        self.scale = scale
        self.mode = mode
        self.distribution = distribution
        self.seed = seed
        self.dtype = config.real(tf)

    def __call__(self, shape, dtype=None):
        if dtype is None:
            dtype = self.dtype
        scale = self.scale
        fan_in, fan_out = _compute_fans_stacked(shape)
        # Rescale the variance by the chosen fan measure.
        if self.mode == "fan_in":
            scale /= max(1.0, fan_in)
        elif self.mode == "fan_out":
            scale /= max(1.0, fan_out)
        else:
            scale /= max(1.0, (fan_in + fan_out) / 2.0)
        if self.distribution in ("normal", "truncated_normal"):
            # Constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
            stddev = math.sqrt(scale) / 0.87962566103423978
            return tf.truncated_normal(shape, 0.0, stddev, dtype, seed=self.seed)
        if self.distribution == "untruncated_normal":
            stddev = math.sqrt(scale)
            return tf.random_normal(shape, 0.0, stddev, dtype, seed=self.seed)
        limit = math.sqrt(3.0 * scale)
        return tf.random_uniform(shape, -limit, limit, dtype, seed=self.seed)
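# For concreteness, a minimal sketch (plain Python, no backend needed) of the
# statistics the initializer above targets. The kernel shape and the use of
# "stacked He normal" settings (scale=2.0, mode="fan_in", truncated normal)
# are illustrative assumptions, not values taken from the module:
#
#     import math
#
#     fan_in = 128  # from a hypothetical stacked kernel (num_stack, 128, 64)
#     scale = 2.0 / max(1.0, fan_in)
#     # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
#     stddev = math.sqrt(scale) / 0.87962566103423978
#     print(stddev)  # ~0.1421; samples come from a truncated normal with this stddev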
def _compute_fans_stacked(shape):
    """Computes the number of input and output units for a weight shape.

    Args:
        shape: Integer shape tuple or TF tensor shape.

    Returns:
        A tuple of scalars (fan_in, fan_out).
    """
    if len(shape) < 1:
        # Just to avoid errors for constants.
        fan_in = fan_out = 1
    elif len(shape) == 1:
        fan_in = fan_out = shape[0]
    elif len(shape) == 2:
        # Assuming stacked NN. Kernel shape: (num_stack, fan_in)
        fan_in = shape[1]
        fan_out = 1
    else:
        # Assuming stacked NN. Kernel shape: (..., fan_in, fan_out)
        fan_in = shape[-2]
        fan_out = shape[-1]
    return fan_in, fan_out


def initializer_dict_tf():
    return {
        "Glorot normal": tf.keras.initializers.glorot_normal(),
        "Glorot uniform": tf.keras.initializers.glorot_uniform(),
        "He normal": tf.keras.initializers.he_normal(),
        "He uniform": tf.keras.initializers.he_uniform(),
        "LeCun normal": tf.keras.initializers.lecun_normal(),
        "LeCun uniform": tf.keras.initializers.lecun_uniform(),
        "Orthogonal": tf.keras.initializers.Orthogonal(),
        "zeros": tf.zeros_initializer(),
        # Initializers of stacked DeepONet
        "stacked He normal": VarianceScalingStacked(scale=2.0),
        "stacked He uniform": VarianceScalingStacked(scale=2.0, distribution="uniform"),
        "stacked LeCun normal": VarianceScalingStacked(),
        "stacked LeCun uniform": VarianceScalingStacked(distribution="uniform"),
    }


def initializer_dict_torch():
    return {
        "Glorot normal": torch.nn.init.xavier_normal_,
        "Glorot uniform": torch.nn.init.xavier_uniform_,
        "He normal": torch.nn.init.kaiming_normal_,
        "He uniform": torch.nn.init.kaiming_uniform_,
        "zeros": torch.nn.init.zeros_,
    }


def initializer_dict_jax():
    return {
        "Glorot normal": jax.nn.initializers.glorot_normal(),
        "Glorot uniform": jax.nn.initializers.glorot_uniform(),
        "He normal": jax.nn.initializers.he_normal(),
        "He uniform": jax.nn.initializers.he_uniform(),
        "LeCun normal": jax.nn.initializers.lecun_normal(),
        "LeCun uniform": jax.nn.initializers.lecun_uniform(),
        "zeros": jax.nn.initializers.zeros,
    }


def initializer_dict_paddle():
    return {
        "Glorot normal": paddle.nn.initializer.XavierNormal(),
        "Glorot uniform": paddle.nn.initializer.XavierUniform(),
        "He normal": paddle.nn.initializer.KaimingNormal(),
        "He uniform": paddle.nn.initializer.KaimingUniform(),
        "zeros": paddle.nn.initializer.Constant(0.0),
    }


# Select the initializer dictionary for the active backend.
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
    INITIALIZER_DICT = initializer_dict_tf()
elif backend_name == "pytorch":
    INITIALIZER_DICT = initializer_dict_torch()
elif backend_name == "jax":
    INITIALIZER_DICT = initializer_dict_jax()
elif backend_name == "paddle":
    INITIALIZER_DICT = initializer_dict_paddle()
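# To illustrate the shape conventions of _compute_fans_stacked, a short usage
# sketch; the example shapes are assumptions for illustration only:
#
#     print(_compute_fans_stacked(()))             # (1, 1): constant
#     print(_compute_fans_stacked((50,)))          # (50, 50): 1-D weight
#     print(_compute_fans_stacked((10, 128)))      # (128, 1): (num_stack, fan_in)
#     print(_compute_fans_stacked((10, 128, 64)))  # (128, 64): (num_stack, fan_in, fan_out)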
def get(identifier):
    """Retrieve an initializer by the identifier.

    Args:
        identifier: String that contains the initializer name or an initializer
            function.

    Returns:
        Initializer instance based on the input identifier.
    """
    if isinstance(identifier, str):
        return INITIALIZER_DICT[identifier]
    if callable(identifier):
        return identifier
    raise ValueError("Could not interpret initializer identifier: " + str(identifier))
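# A brief usage sketch of get; the user-defined callable below is hypothetical:
#
#     from deepxde.nn import initializers
#
#     init = initializers.get("Glorot uniform")   # backend-specific initializer
#
#     def my_init(shape):                          # hypothetical custom initializer
#         ...
#
#     assert initializers.get(my_init) is my_init  # callables pass through unchanged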