import numpy as np
from ... import config
from ...backend import tf
from ...utils import make_dict, timing
class NN:
    """Base class for all neural network modules."""

    def __init__(self):
        # Boolean placeholder; its value is supplied by ``feed_dict(training=...)``.
        self.training = tf.placeholder(tf.bool)
        self.regularizer = None
        # tf.layers.dense() is not available for TensorFlow 2.16+ with Keras 3. The
        # corresponding layer is tf.keras.layers.Dense(). However, tf.layers.dense()
        # adds regularizer loss to the collection REGULARIZATION_LOSSES, which can be
        # accessed by tf.losses.get_regularization_loss(), but tf.keras.layers.Dense()
        # adds regularizer loss to Layer.losses. Hence, we use self.regularization_loss
        # to collect tf.keras.layers.Dense() regularization loss.
        # References:
        # - https://github.com/tensorflow/tensorflow/issues/21587
        self.regularization_loss = 0
        # Rank-2 placeholder with both dimensions left unspecified; its dtype is
        # whatever ``config.real(tf)`` selects.
        self._auxiliary_vars = tf.placeholder(config.real(tf), [None, None])
        self._input_transform = None
        self._output_transform = None
        self._built = False  # The property will be set upon call of self.build()

    @property
    def inputs(self):
        """Return the net inputs (placeholders).

        The base implementation has no body and returns ``None``; presumably
        concrete subclasses override it.
        """

    @property
    def outputs(self):
        """Return the net outputs (tf.Tensor).

        The base implementation has no body and returns ``None``; presumably
        concrete subclasses override it.
        """

    @property
    def targets(self):
        """Return the targets of the net outputs (placeholders).

        The base implementation has no body and returns ``None``; presumably
        concrete subclasses override it.
        """

    @property
    def auxiliary_vars(self):
        """Return additional variables needed (placeholders)."""
        return self._auxiliary_vars

    @property
    def built(self):
        """Whether the network has been built.

        ``False`` after construction; set to ``True`` by ``build()``.
        """
        return self._built

    @built.setter
    def built(self, value):
        self._built = value

    def feed_dict(self, training, inputs, targets=None, auxiliary_vars=None):
        """Construct a feed_dict to feed values to TensorFlow placeholders.

        Args:
            training: Value fed to the ``self.training`` placeholder.
            inputs: Values paired with ``self.inputs`` via ``make_dict``.
            targets: Optional values paired with ``self.targets``; skipped when
                ``None``.
            auxiliary_vars: Optional values paired with ``self.auxiliary_vars``;
                skipped when ``None``.

        Returns:
            dict: Mapping from placeholders to the values to feed.
        """
        feed_dict = {self.training: training}
        feed_dict.update(self._feed_dict_inputs(inputs))
        if targets is not None:
            feed_dict.update(self._feed_dict_targets(targets))
        if auxiliary_vars is not None:
            feed_dict.update(self._feed_dict_auxiliary_vars(auxiliary_vars))
        return feed_dict

    def _feed_dict_inputs(self, inputs):
        # Pair the input placeholders with the supplied input values.
        return make_dict(self.inputs, inputs)

    def _feed_dict_targets(self, targets):
        # Pair the target placeholders with the supplied target values.
        return make_dict(self.targets, targets)

    def _feed_dict_auxiliary_vars(self, auxiliary_vars):
        # Pair the auxiliary-variable placeholder with the supplied values.
        return make_dict(self.auxiliary_vars, auxiliary_vars)

    def num_trainable_parameters(self):
        """Evaluate the number of trainable parameters for the NN.

        Notice that the function counts every variable returned by
        ``tf.trainable_variables()``, not only the variables of this net, so the
        result will not be correct if several nets are defined together.

        Returns:
            int: Total number of scalar entries over all trainable variables;
            ``0`` when there are none.
        """
        # ``np.sum([])`` is the float ``0.0`` and ``np.prod([])`` (a scalar
        # variable) is the float ``1.0``. Converting each product to ``int`` and
        # summing with the builtin always yields a Python int.
        return sum(
            int(np.prod(v.get_shape().as_list())) for v in tf.trainable_variables()
        )

    @timing
    def build(self):
        """Construct the network.

        The base implementation only marks the network as built. It is wrapped by
        the project's ``timing`` decorator.
        """
        self.built = True