# Source code for deepxde.nn.jax.fnn

from typing import Any, Callable

import jax
from flax import linen as nn

from .nn import NN
from .. import activations
from .. import initializers


class FNN(NN):
    """Fully-connected neural network.

    Attributes:
        layer_sizes: Width of every layer, input layer first. One ``nn.Dense``
            layer is created for each entry after the first.
        activation: Either a single activation identifier, applied after every
            hidden layer, or a list of identifiers with exactly
            ``len(layer_sizes) - 1`` entries (one per hidden layer plus one for
            the output layer). Identifiers are resolved with
            ``activations.get``. Note that ``__call__`` never applies an
            activation after the output layer, so the last list entry is
            validated but unused.
        kernel_initializer: Identifier resolved with ``initializers.get`` and
            used as ``kernel_init`` for every ``nn.Dense`` layer.
        params: Not read by this module. NOTE(review): presumably holds
            parameters for the surrounding training code; confirm.
        _input_transform: Optional callable ``f(inputs)`` applied to the
            inputs before the first layer.
        _output_transform: Optional callable ``f(inputs, outputs)`` applied to
            the network output. It receives the original, untransformed inputs.
    """

    layer_sizes: Any
    activation: Any
    kernel_initializer: Any
    params: Any = None
    _input_transform: Callable = None
    _output_transform: Callable = None

    def setup(self):
        """Resolve activations and build the dense layers.

        Raises:
            ValueError: If ``activation`` is a list whose length differs from
                ``len(layer_sizes) - 1``.
        """
        # TODO: implement get regularizer
        if isinstance(self.activation, list):
            if len(self.activation) != len(self.layer_sizes) - 1:
                raise ValueError(
                    "Total number of activation functions do not match with sum of hidden layers and output layer!"
                )
            self._activation = list(map(activations.get, self.activation))
        else:
            self._activation = activations.get(self.activation)

        kernel_initializer = initializers.get(self.kernel_initializer)
        bias_initializer = jax.nn.initializers.zeros
        self.denses = [
            nn.Dense(
                unit,
                kernel_init=kernel_initializer,
                bias_init=bias_initializer,
            )
            for unit in self.layer_sizes[1:]
        ]

    def __call__(self, inputs, training=False):
        """Run the forward pass.

        Args:
            inputs: Network input, passed to the optional input transform and
                then through the dense layers.
            training: Accepted for interface compatibility; unused here.

        Returns:
            The output of the last dense layer (with no activation applied),
            optionally passed through ``_output_transform(inputs, outputs)``.
        """
        x = inputs
        if self._input_transform is not None:
            x = self._input_transform(x)
        # Flax linen freezes list attributes assigned in ``setup`` into tuples,
        # so a per-layer activation list may arrive here as a tuple. Checking
        # only for ``list`` would wrongly try to call the tuple itself.
        per_layer = isinstance(self._activation, (list, tuple))
        for j, linear in enumerate(self.denses[:-1]):
            activation = self._activation[j] if per_layer else self._activation
            x = activation(linear(x))
        # The output layer is linear; any activation for it is not applied.
        x = self.denses[-1](x)
        if self._output_transform is not None:
            x = self._output_transform(inputs, x)
        return x