from typing import Any, Callable, Optional
import jax
from flax import linen as nn
from .nn import NN
from .. import activations
from .. import initializers
class FNN(NN):
    """Fully-connected neural network."""

    layer_sizes: Any
    activation: Any
    kernel_initializer: Any

    params: Any = None
    _input_transform: Optional[Callable] = None
    _output_transform: Optional[Callable] = None
    def setup(self):
        # TODO: implement get regularizer
        if isinstance(self.activation, list):
            if len(self.activation) != len(self.layer_sizes) - 1:
                raise ValueError(
                    "Total number of activation functions does not match the "
                    "number of hidden layers plus the output layer!"
                )
            self._activation = list(map(activations.get, self.activation))
        else:
            self._activation = activations.get(self.activation)
        kernel_initializer = initializers.get(self.kernel_initializer)
        # Biases are zero-initialized.
        initializer = jax.nn.initializers.zeros
        # One Dense layer per entry in layer_sizes after the input size.
        self.denses = [
            nn.Dense(
                unit,
                kernel_init=kernel_initializer,
                bias_init=initializer,
            )
            for unit in self.layer_sizes[1:]
        ]

    def __call__(self, inputs, training=False):
        x = inputs
        if self._input_transform is not None:
            x = self._input_transform(x)
        # Hidden layers: use the per-layer activation if a list was given,
        # otherwise the single shared activation.
        for j, linear in enumerate(self.denses[:-1]):
            x = (
                self._activation[j](linear(x))
                if isinstance(self._activation, list)
                else self._activation(linear(x))
            )
        # Output layer: no activation is applied.
        x = self.denses[-1](x)
        if self._output_transform is not None:
            x = self._output_transform(inputs, x)
        return x
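
# Minimal usage sketch (assumptions: this module sits in a DeepXDE-style JAX
# backend, and "tanh" / "Glorot uniform" are valid keys for activations.get and
# initializers.get; the snippet is illustrative, not part of the module API):
#
#     import jax
#     import jax.numpy as jnp
#
#     net = FNN(
#         layer_sizes=[2, 32, 32, 1],  # 2 inputs, two hidden layers, 1 output
#         activation="tanh",
#         kernel_initializer="Glorot uniform",
#     )
#     x = jnp.ones((8, 2))  # batch of 8 two-dimensional inputs
#     variables = net.init(jax.random.PRNGKey(0), x)
#     y = net.apply(variables, x)  # shape (8, 1)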