Source code for deepxde.nn.jax.nn
from flax import linen as nn
[docs]
class NN(nn.Module):
    """Common base class shared by every neural network module.

    It offers two hooks for registering user-supplied transforms: one for the
    network inputs and one for the network outputs. Each registered callable is
    wrapped so that an unbatched array of shape (n,) is temporarily promoted to
    shape (1, n) before the user transform sees it, and flattened afterwards.
    """

    # Every concrete sub-module is expected to declare these fields:
    #     params: Any = None
    #     _input_transform: Optional[Callable] = None
    #     _output_transform: Optional[Callable] = None

    def apply_feature_transform(self, transform):
        """Register a feature transform on the network inputs, i.e.,
        features = transform(inputs).

        The wrapped callable is stored in ``self._input_transform``; the
        intended use is outputs = network(features).

        Args:
            transform: Callable taking the inputs and returning the features.
        """

        def batched_feature_transform(x):
            """Call ``transform`` on ``x``, coping with inputs of shape (n,)."""
            # TODO: Support tuple or list inputs.
            # Lists/tuples are forwarded untouched and never queried for `ndim`.
            is_flat_array = not isinstance(x, (list, tuple)) and x.ndim == 1
            if not is_flat_array:
                return transform(x)
            # Add a leading batch axis for the transform, then drop it again.
            batched = x.reshape(1, -1)
            return transform(batched).reshape(-1)

        self._input_transform = batched_feature_transform

    def apply_output_transform(self, transform):
        """Register a transform on the network outputs, i.e.,
        outputs = transform(inputs, outputs).

        The wrapped callable is stored in ``self._output_transform``.

        Args:
            transform: Callable taking ``(inputs, outputs)`` and returning the
                transformed outputs.
        """

        def batched_output_transform(inputs, outputs):
            """Call ``transform``, coping with ``inputs`` of shape (n,)."""
            # TODO: Support tuple or list inputs.
            # Only `inputs` decides whether batching is needed; lists/tuples
            # are forwarded untouched and never queried for `ndim`.
            is_flat_array = (
                not isinstance(inputs, (list, tuple)) and inputs.ndim == 1
            )
            if not is_flat_array:
                return transform(inputs, outputs)
            # Give both arguments a leading batch axis, then flatten the result.
            batched_inputs = inputs.reshape(1, -1)
            batched_outputs = outputs.reshape(1, -1)
            return transform(batched_inputs, batched_outputs).reshape(-1)

        self._output_transform = batched_output_transform