Source code for deepxde.nn.tensorflow_compat_v1.nn

import numpy as np

from ... import config
from ...backend import tf
from ...utils import make_dict, timing


[docs] class NN: """Base class for all neural network modules.""" def __init__(self): self.training = tf.placeholder(tf.bool) self.regularizer = None self._auxiliary_vars = tf.placeholder(config.real(tf), [None, None]) self._input_transform = None self._output_transform = None self._built = False # The property will be set upon call of self.build() @property def inputs(self): """Return the net inputs (placeholders).""" @property def outputs(self): """Return the net outputs (tf.Tensor).""" @property def targets(self): """Return the targets of the net outputs (placeholders).""" @property def auxiliary_vars(self): """Return additional variables needed (placeholders).""" return self._auxiliary_vars @property def built(self): return self._built @built.setter def built(self, value): self._built = value
[docs] def feed_dict(self, training, inputs, targets=None, auxiliary_vars=None): """Construct a feed_dict to feed values to TensorFlow placeholders.""" feed_dict = {self.training: training} feed_dict.update(self._feed_dict_inputs(inputs)) if targets is not None: feed_dict.update(self._feed_dict_targets(targets)) if auxiliary_vars is not None: feed_dict.update(self._feed_dict_auxiliary_vars(auxiliary_vars)) return feed_dict
def _feed_dict_inputs(self, inputs): return make_dict(self.inputs, inputs) def _feed_dict_targets(self, targets): return make_dict(self.targets, targets) def _feed_dict_auxiliary_vars(self, auxiliary_vars): return make_dict(self.auxiliary_vars, auxiliary_vars)
[docs] def apply_feature_transform(self, transform): """Compute the features by appling a transform to the network inputs, i.e., features = transform(inputs). Then, outputs = network(features). """ self._input_transform = transform
[docs] def apply_output_transform(self, transform): """Apply a transform to the network outputs, i.e., outputs = transform(inputs, outputs). """ self._output_transform = transform
[docs] def num_trainable_parameters(self): """Evaluate the number of trainable parameters for the NN. Notice that the function returns the number of trainable parameters for the whole tf.Session, so that it will not be correct if several nets are defined within the same tf.Session. """ return np.sum( [np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()] )
[docs] @timing def build(self): """Construct the network.""" self.built = True