import paddle
from .fnn import FNN
from .nn import NN
from .. import activations
from .. import initializers
class DeepONet(NN):
    """Deep operator network.

    `Lu et al. Learning nonlinear operators via DeepONet based on the universal
    approximation theorem of operators. Nat Mach Intell, 2021.
    <https://doi.org/10.1038/s42256-021-00302-5>`_

    Args:
        layer_sizes_branch: A list of integers as the width of a fully connected
            network, or `(dim, f)` where `dim` is the input dimension and `f` is a
            network function. The width of the last layer in the branch and trunk net
            should be equal.
        layer_sizes_trunk (list): A list of integers as the width of a fully connected
            network.
        activation: If `activation` is a ``string``, then the same activation is used in
            both trunk and branch nets. If `activation` is a ``dict``, then the trunk
            net uses the activation `activation["trunk"]`, and the branch net uses
            `activation["branch"]`.
        kernel_initializer: Initializer spec resolved via ``initializers.get`` and
            used for both sub-networks.
        use_bias: If True, a trainable scalar bias is added to the dot-product
            output.
    """

    def __init__(
        self,
        layer_sizes_branch,
        layer_sizes_trunk,
        activation,
        kernel_initializer,
        use_bias=True,
    ):
        super().__init__()
        self.layer_sizes_func = layer_sizes_branch
        self.layer_sizes_loc = layer_sizes_trunk
        if isinstance(activation, dict):
            self.activation_branch = activations.get(activation["branch"])
            self.activation_trunk = activations.get(activation["trunk"])
        else:
            # BUG FIX: this previously bound a *local* name `activation_branch`
            # while the dict case above bound `self.activation_branch`; the FNN
            # construction below read the local name, so passing a dict
            # activation raised NameError. Bind the attribute in both branches
            # and read the attribute below.
            self.activation_branch = self.activation_trunk = activations.get(activation)
        self.kernel_initializer = initializers.get(kernel_initializer)
        if callable(layer_sizes_branch[1]):
            # User-defined network
            self.branch = layer_sizes_branch[1]
        else:
            # Fully connected network
            self.branch = FNN(
                layer_sizes_branch, self.activation_branch, kernel_initializer
            )
        self.trunk = FNN(layer_sizes_trunk, self.activation_trunk, kernel_initializer)
        self.use_bias = use_bias
        if use_bias:
            # register bias to parameter for updating in optimizer and storage
            self.b = self.create_parameter(
                shape=(1,), default_initializer=initializers.get("zeros")
            )

    def forward(self, inputs):
        """Apply branch and trunk nets and combine them by a dot product.

        Args:
            inputs: A pair ``(x_func, x_loc)`` — branch-net input (function
                samples) and trunk-net input (evaluation locations).

        Returns:
            Tensor of shape ``(batch, 1)``: per-row dot product of branch and
            trunk outputs, plus the optional bias, with any registered output
            transform applied.
        """
        x_func = inputs[0]
        x_loc = inputs[1]
        # Branch net to encode the input function
        x_func = self.branch(x_func)
        # Trunk net to encode the domain of the output function
        if self._input_transform is not None:
            x_loc = self._input_transform(x_loc)
        x_loc = self.activation_trunk(self.trunk(x_loc))
        # Dot product
        if x_func.shape[-1] != x_loc.shape[-1]:
            raise AssertionError(
                "Output sizes of branch net and trunk net do not match."
            )
        x = paddle.sum(x_func * x_loc, axis=1, keepdim=True)
        # Add bias
        if self.use_bias:
            x += self.b
        if self._output_transform is not None:
            x = self._output_transform(inputs, x)
        return x
class DeepONetCartesianProd(NN):
    """Deep operator network for dataset in the format of Cartesian product.

    The branch output (one row per input function) is matrix-multiplied against
    the transposed trunk output (one row per location), producing predictions
    for every function at every location.

    Args:
        layer_sizes_branch: A list of integers as the width of a fully connected network,
            or `(dim, f)` where `dim` is the input dimension and `f` is a network
            function. The width of the last layer in the branch and trunk net should be
            equal.
        layer_sizes_trunk (list): A list of integers as the width of a fully connected
            network.
        activation: If `activation` is a ``string``, then the same activation is used in
            both trunk and branch nets. If `activation` is a ``dict``, then the trunk
            net uses the activation `activation["trunk"]`, and the branch net uses
            `activation["branch"]`.
        kernel_initializer: Initializer spec used for both sub-networks.
        regularization: Optional regularization spec, stored as ``self.regularizer``.
    """

    def __init__(
        self,
        layer_sizes_branch,
        layer_sizes_trunk,
        activation,
        kernel_initializer,
        regularization=None,
    ):
        super().__init__()
        if isinstance(activation, dict):
            branch_act = activation["branch"]
            self.activation_trunk = activations.get(activation["trunk"])
        else:
            branch_act = self.activation_trunk = activations.get(activation)
        # A callable second entry means the caller supplied their own branch
        # network; otherwise build a fully connected one.
        if callable(layer_sizes_branch[1]):
            self.branch = layer_sizes_branch[1]
        else:
            self.branch = FNN(layer_sizes_branch, branch_act, kernel_initializer)
        self.trunk = FNN(layer_sizes_trunk, self.activation_trunk, kernel_initializer)
        # Registered as a parameter so the optimizer updates it and it is
        # included when the model is saved/loaded.
        self.b = self.create_parameter(
            shape=(1,), default_initializer=initializers.get("zeros")
        )
        self.regularizer = regularization

    def forward(self, inputs):
        """Encode both inputs and combine them via a Cartesian-product matmul.

        Args:
            inputs: A pair ``(x_func, x_loc)`` — branch-net input and trunk-net
                input.

        Returns:
            Tensor of shape ``(n_functions, n_locations)`` plus the scalar
            bias, with any registered output transform applied.
        """
        x_func, x_loc = inputs[0], inputs[1]
        # Branch net encodes the sampled input functions.
        x_func = self.branch(x_func)
        # Trunk net encodes the evaluation locations.
        if self._input_transform is not None:
            x_loc = self._input_transform(x_loc)
        x_loc = self.activation_trunk(self.trunk(x_loc))
        # The latent widths must agree for the product below to make sense.
        if x_loc.shape[-1] != x_func.shape[-1]:
            raise AssertionError(
                "Output sizes of branch net and trunk net do not match."
            )
        out = x_func @ x_loc.T
        out += self.b
        if self._output_transform is not None:
            out = self._output_transform(inputs, out)
        return out