Source code for deepxde.nn.pytorch.nn
import torch
[docs]
class NN(torch.nn.Module):
    """Base class for all neural network modules.

    Holds state shared by concrete networks: an optional regularizer, optional
    auxiliary variables, and optional input/output transforms. This base class
    only stores these values.

    NOTE(review): subclasses are presumably responsible for invoking the stored
    transforms inside their ``forward`` -- confirm against the concrete networks.
    """

    def __init__(self):
        super().__init__()
        # Every optional hook starts unset (None).
        # NOTE(review): the expected format of ``regularizer`` is not visible
        # here -- confirm against the code that reads it.
        self.regularizer = None
        self._auxiliary_vars = None
        self._input_transform = None
        self._output_transform = None

    @property
    def auxiliary_vars(self):
        """Tensors: Any additional variables needed."""
        return self._auxiliary_vars

    @auxiliary_vars.setter
    def auxiliary_vars(self, value):
        # Stored as-is; no validation or copying is performed.
        self._auxiliary_vars = value

    def apply_feature_transform(self, transform):
        """Compute the features by applying a transform to the network inputs, i.e.,
        features = transform(inputs). Then, outputs = network(features).

        This method only records ``transform`` (replacing any previously
        registered one); it does not call it.

        Args:
            transform: Callable taking the network inputs and returning the
                features.
        """
        self._input_transform = transform

    def apply_output_transform(self, transform):
        """Apply a transform to the network outputs, i.e.,
        outputs = transform(inputs, outputs).

        This method only records ``transform`` (replacing any previously
        registered one); it does not call it.

        Args:
            transform: Callable taking ``(inputs, outputs)`` and returning the
                transformed outputs.
        """
        self._output_transform = transform

    def num_trainable_parameters(self):
        """Evaluate the number of trainable parameters for the NN.

        Returns:
            The total element count over all parameters of this module whose
            ``requires_grad`` flag is set.
        """
        return sum(v.numel() for v in self.parameters() if v.requires_grad)