import torch
from .nn import NN
from .. import activations
from .. import initializers
from ... import config
[docs]
class FNN(NN):
    """Fully-connected feed-forward neural network.

    Args:
        layer_sizes: Layer widths, from the input size through the output size.
        activation: Either a single activation identifier (resolved with
            ``activations.get``) used after every hidden layer, or a list of
            exactly ``len(layer_sizes) - 1`` identifiers, one per non-input
            layer. ``forward`` applies entry ``j`` after hidden layer ``j``; the
            final entry is accepted but the output layer is left linear.
        kernel_initializer: Identifier passed to ``initializers.get`` for the
            weight matrices. Biases use the ``"zeros"`` initializer.
        regularization: Stored unchanged on ``self.regularizer``.
        dropout_rate: A scalar rate shared by every layer, or a list of exactly
            ``len(layer_sizes) - 1`` rates. Dropout is applied after a hidden
            layer's activation only when that layer's rate is > 0.

    Raises:
        ValueError: If a list ``activation`` or a list ``dropout_rate`` does not
            have ``len(layer_sizes) - 1`` entries.
    """

    def __init__(
        self,
        layer_sizes,
        activation,
        kernel_initializer,
        regularization=None,
        dropout_rate=0,
    ):
        super().__init__()
        # Number of weight matrices = number of non-input layers.
        n_transitions = len(layer_sizes) - 1

        if isinstance(activation, list):
            if len(activation) != n_transitions:
                raise ValueError(
                    "Total number of activation functions do not match with sum of hidden layers and output layer!"
                )
            self.activation = [activations.get(act) for act in activation]
        else:
            self.activation = activations.get(activation)

        if not isinstance(dropout_rate, list):
            # Broadcast a scalar rate to every layer.
            self.dropout_rate = [dropout_rate] * n_transitions
        elif len(dropout_rate) == n_transitions:
            self.dropout_rate = dropout_rate
        else:
            raise ValueError(
                f"Number of dropout rates must be equal to {len(layer_sizes) - 1}"
            )

        weight_init = initializers.get(kernel_initializer)
        bias_init = initializers.get("zeros")
        self.regularizer = regularization

        self.linears = torch.nn.ModuleList()
        # Build and initialize one layer at a time so the default init and the
        # custom init of each layer stay interleaved (same RNG consumption order).
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            dense = torch.nn.Linear(fan_in, fan_out, dtype=config.real(torch))
            weight_init(dense.weight)
            bias_init(dense.bias)
            self.linears.append(dense)

    def forward(self, inputs):
        """Apply the network to ``inputs``.

        The optional input transform runs first. Every hidden layer is followed
        by its activation (and by dropout when its rate is > 0). The last layer
        is linear. The optional output transform receives the original
        ``inputs`` together with the network output.
        """
        hidden = (
            inputs
            if self._input_transform is None
            else self._input_transform(inputs)
        )
        per_layer = isinstance(self.activation, list)
        for idx, dense in enumerate(self.linears[:-1]):
            act_fn = self.activation[idx] if per_layer else self.activation
            hidden = act_fn(dense(hidden))
            rate = self.dropout_rate[idx]
            if rate > 0:
                hidden = torch.nn.functional.dropout(
                    hidden, p=rate, training=self.training
                )
        out = self.linears[-1](hidden)
        if self._output_transform is not None:
            out = self._output_transform(inputs, out)
        return out
[docs]
class PFNN(NN):
    """Parallel fully-connected network that uses independent sub-networks for each
    network output.

    Args:
        layer_sizes: A nested list that defines the architecture of the neural
            network (how the layers are connected). ``layer_sizes[0]`` must be an
            int (the input size). If ``layer_sizes[i]`` is an int, it represents
            one layer shared by all the outputs; if ``layer_sizes[i]`` is a list,
            it represents ``len(layer_sizes[i])`` sub-layers, each of which is
            exclusively used by one output. At least one list is required, and
            every list in ``layer_sizes`` must have the same length (= number of
            subnetworks). A shared layer that follows a list layer takes the
            concatenation of all subnetwork outputs as its input.
            For the last element of ``layer_sizes``:

            - An int preceded by a list must equal the number of subnetworks.
              Every subnetwork then has an output size of 1, and the outputs are
              concatenated.
            - A list specifies the output size of each subnetwork before
              concatenation. This also works when the preceding layer is shared;
              the shared features then feed every subnetwork.
            - An int preceded by an int is a single shared linear output layer.
        activation: Activation function, applied after every hidden layer. The
            output layer is linear.
        kernel_initializer: Initializer for the kernel weights. Biases are
            initialized to zero.

    Raises:
        ValueError: If ``layer_sizes`` has fewer than two entries, its first
            entry is not an int, it contains no list layer, its list layers
            differ in length, or an int output layer that follows a list layer
            differs from the number of subnetworks.
    """

    def __init__(self, layer_sizes, activation, kernel_initializer):
        super().__init__()
        self.activation = activations.get(activation)
        initializer = initializers.get(kernel_initializer)
        initializer_zero = initializers.get("zeros")

        if len(layer_sizes) <= 1:
            raise ValueError("must specify input and output sizes")
        if not isinstance(layer_sizes[0], int):
            raise ValueError("input size must be integer")

        # Determine the number of subnetworks from the first list layer.
        list_layers = [
            layer for layer in layer_sizes if isinstance(layer, (list, tuple))
        ]
        if not list_layers:
            raise ValueError(
                "No list layers found; use FNN instead of PFNN for single subnetwork."
            )
        n_subnetworks = len(list_layers[0])
        if any(len(layer) != n_subnetworks for layer in list_layers):
            raise ValueError(
                "All list layers must have the same length as the first list layer."
            )

        # Validate the output layer if it is an int preceded by a list layer.
        if (
            isinstance(layer_sizes[-1], int)
            and isinstance(layer_sizes[-2], (list, tuple))
            and layer_sizes[-1] != n_subnetworks
        ):
            raise ValueError(
                "If last layer is an int and previous is a list, the int must equal the number of subnetworks."
            )

        def make_linear(n_input, n_output):
            linear = torch.nn.Linear(n_input, n_output, dtype=config.real(torch))
            initializer(linear.weight)
            initializer_zero(linear.bias)
            return linear

        def make_parallel(prev_layer, sizes):
            # One Linear per subnetwork, created in subnetwork order. A parallel
            # predecessor feeds subnetwork j only its own output; a shared
            # predecessor feeds the same features to every subnetwork.
            if isinstance(prev_layer, (list, tuple)):
                return torch.nn.ModuleList(
                    make_linear(prev_layer[j], sizes[j]) for j in range(n_subnetworks)
                )
            return torch.nn.ModuleList(
                make_linear(prev_layer, sizes[j]) for j in range(n_subnetworks)
            )

        self.layers = torch.nn.ModuleList()

        # Hidden layers: every (previous, current) pair except the output layer.
        for prev_layer, curr_layer in zip(layer_sizes[:-2], layer_sizes[1:-1]):
            if isinstance(curr_layer, (list, tuple)):
                self.layers.append(make_parallel(prev_layer, curr_layer))
            else:
                # A shared layer after a parallel one consumes the concatenation
                # of all subnetwork outputs.
                if isinstance(prev_layer, (list, tuple)):
                    input_size = sum(prev_layer)
                else:
                    input_size = prev_layer
                self.layers.append(make_linear(input_size, curr_layer))

        # Output layer.
        prev_output_layer = layer_sizes[-2]
        output_layer = layer_sizes[-1]
        if isinstance(output_layer, (list, tuple)):
            self.layers.append(make_parallel(prev_output_layer, output_layer))
        elif isinstance(prev_output_layer, (list, tuple)):
            # Each subnetwork outputs a single value; forward() concatenates them.
            self.layers.append(make_parallel(prev_output_layer, [1] * n_subnetworks))
        else:
            # Shared output layer.
            self.layers.append(make_linear(prev_output_layer, output_layer))

    @staticmethod
    def _run_parallel(layer, x):
        """Apply each sub-layer of the ModuleList ``layer`` and return a list of outputs.

        If ``x`` is a list (one tensor per subnetwork), sub-layer j receives
        ``x[j]``. Otherwise ``x`` is a shared tensor fed to every sub-layer.
        """
        if isinstance(x, list):
            return [f(x_) for f, x_ in zip(layer, x)]
        return [f(x) for f in layer]

    def forward(self, inputs):
        """Apply the network to ``inputs``.

        Subnetwork outputs are concatenated along the last dimension, which is
        the feature dimension ``torch.nn.Linear`` operates on.
        """
        x = inputs
        if self._input_transform is not None:
            x = self._input_transform(x)

        for layer in self.layers[:-1]:
            if isinstance(layer, torch.nn.ModuleList):
                # Parallel layer: x becomes a list with one tensor per subnetwork.
                x = [self.activation(y) for y in self._run_parallel(layer, x)]
            else:
                # Shared layer: merge subnetwork outputs first if necessary.
                if isinstance(x, list):
                    x = torch.cat(x, dim=-1)
                x = self.activation(layer(x))

        out_layer = self.layers[-1]
        if isinstance(out_layer, torch.nn.ModuleList):
            # Parallel output. x may be per-subnetwork (list) OR shared (tensor),
            # e.g. layer_sizes=[2, 16, [1, 1]]. The ModuleList itself has no forward
            # and must never be called directly.
            x = torch.cat(self._run_parallel(out_layer, x), dim=-1)
        else:
            # A shared output Linear is only built after an int layer, so x is a tensor.
            x = out_layer(x)

        if self._output_transform is not None:
            x = self._output_transform(inputs, x)
        return x