Source code for deepxde.nn.jax.nn
from flax import linen as nn
[docs]
class NN(nn.Module):
    """Common base class shared by the neural network modules.

    The two methods defined here only record a user-supplied callable on the
    instance; nothing in the code shown here invokes the stored transforms.
    NOTE(review): presumably concrete sub-modules read ``_input_transform`` and
    ``_output_transform`` in their forward pass -- confirm against subclasses.
    """

    # Every concrete sub-module is expected to declare these fields:
    #     params: Any = None
    #     _input_transform: Optional[Callable] = None
    #     _output_transform: Optional[Callable] = None

    def apply_feature_transform(self, transform):
        """Register a feature transform for the network inputs.

        The intended data flow is ``features = transform(inputs)`` followed by
        ``outputs = network(features)``.

        Args:
            transform: Callable that maps the network inputs to features.
        """
        self._input_transform = transform

    def apply_output_transform(self, transform):
        """Register a transform for the network outputs.

        The intended data flow is ``outputs = transform(inputs, outputs)``.

        Args:
            transform: Callable that receives the network inputs and outputs
                and returns the transformed outputs.
        """
        self._output_transform = transform