import numpy as np
from ...backend import tf
[docs]
class NN(tf.keras.Model):
"""Base class for all neural network modules."""
def __init__(self):
super().__init__()
self.regularizer = None
self._auxiliary_vars = None
self._input_transform = None
self._output_transform = None
@property
def auxiliary_vars(self):
"""Tensors: Any additional variables needed."""
return self._auxiliary_vars
@auxiliary_vars.setter
def auxiliary_vars(self, value):
self._auxiliary_vars = value
[docs]
def num_trainable_parameters(self):
"""Evaluate the number of trainable parameters for the NN."""
result = np.sum(
[np.prod(v.get_shape().as_list()) for v in self.trainable_variables]
)
if result == 0:
print(
"Warning: The net has to be trained first. \
You need to create a model and run model.compile() and model.train() \
in order to initialize the trainable_variables for the net."
)
return result