__all__ = ["LossHistory", "Model", "TrainState"]
import pickle
from collections import OrderedDict
import numpy as np
from . import config
from . import display
from . import gradients as grad
from . import losses as losses_module
from . import metrics as metrics_module
from . import optimizers
from . import utils
from .backend import backend_name, tf, torch, jax, paddle
from .callbacks import CallbackList
from .utils import list_to_str
class Model:
"""A ``Model`` trains a ``NN`` on a ``Data``.
Args:
data: ``deepxde.data.Data`` instance.
net: ``deepxde.nn.NN`` instance.
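
    Example:
        A minimal sketch of the typical workflow. The geometry, PDE residual, and
        network sizes below are illustrative assumptions, not defaults::

            import deepxde as dde

            def pde(x, y):
                # Residual of dy/dx - y = 0 (illustrative).
                return dde.grad.jacobian(y, x) - y

            geom = dde.geometry.Interval(0, 1)
            data = dde.data.PDE(geom, pde, [], num_domain=16)
            net = dde.nn.FNN([1, 32, 32, 1], "tanh", "Glorot uniform")
            model = dde.Model(data, net)
            model.compile("adam", lr=1e-3)
            losshistory, train_state = model.train(iterations=1000)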
"""
def __init__(self, data, net):
self.data = data
self.net = net
self.opt_name = None
self.batch_size = None
self.loss_weights = None
self.callbacks = None
self.metrics = None
self.external_trainable_variables = []
self.train_state = TrainState()
self.losshistory = LossHistory()
self.stop_training = False
# Backend-dependent attributes
self.opt = None
# Tensor or callable
self.outputs = None
self.outputs_losses_train = None
self.outputs_losses_test = None
self.train_step = None
if backend_name == "tensorflow.compat.v1":
self.sess = None
self.saver = None
elif backend_name in ["pytorch", "paddle"]:
self.lr_scheduler = None
elif backend_name == "jax":
self.opt_state = None
self.params = None
@utils.timing
def compile(
self,
optimizer,
lr=None,
loss="MSE",
metrics=None,
decay=None,
loss_weights=None,
external_trainable_variables=None,
verbose=1,
):
"""Configures the model for training.
Args:
optimizer: String name of an optimizer, or a backend optimizer class
instance.
lr (float): The learning rate. For L-BFGS, use
``dde.optimizers.set_LBFGS_options`` to set the hyperparameters.
loss: If the same loss is used for all errors, then `loss` is a String name
of a loss function or a loss function. If different errors use
different losses, then `loss` is a list whose size is equal to the
number of errors.
metrics: List of metrics to be evaluated by the model during training.
decay (tuple): Name and parameters of decay to the initial learning rate.
One of the following options:
- For backend TensorFlow 1.x:
- `inverse_time_decay <https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/inverse_time_decay>`_: ("inverse time", decay_steps, decay_rate)
- `cosine_decay <https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/cosine_decay>`_: ("cosine", decay_steps, alpha)
- For backend TensorFlow 2.x:
- `InverseTimeDecay <https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/InverseTimeDecay>`_: ("inverse time", decay_steps, decay_rate)
- `CosineDecay <https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay>`_: ("cosine", decay_steps, alpha)
- For backend PyTorch:
- `StepLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html>`_: ("step", step_size, gamma)
- `CosineAnnealingLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html>`_: ("cosine", T_max, eta_min)
                    - InverseTimeLR (PyTorch has no built-in equivalent; mirrors TensorFlow's `InverseTimeDecay <https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/InverseTimeDecay>`_): ("inverse time", decay_steps, decay_rate)
- `ExponentialLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ExponentialLR.html>`_: ("exponential", gamma)
- `LambdaLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.LambdaLR.html>`_: ("lambda", lambda_fn: Callable[[step], float])
- For backend PaddlePaddle:
- `InverseTimeDecay
<https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/optimizer/lr/InverseTimeDecay_en.html>`_:
("inverse time", gamma)
loss_weights: A list specifying scalar coefficients (Python floats) to
weight the loss contributions. The loss value that will be minimized by
the model will then be the weighted sum of all individual losses,
weighted by the `loss_weights` coefficients.
external_trainable_variables: A trainable ``dde.Variable`` object or a list
of trainable ``dde.Variable`` objects. The unknown parameters in the
physics systems that need to be recovered. Regularization will not be
applied to these variables. If the backend is tensorflow.compat.v1,
`external_trainable_variables` is ignored, and all trainable ``dde.Variable``
objects are automatically collected.
verbose (Integer): Controls the verbosity of the compile process.
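
        Example:
            A hedged sketch of a typical call. The learning rate, decay schedule,
            loss weights (assuming two loss terms), and the unknown parameter ``C``
            are illustrative choices, not recommended defaults::

                # Unknown physical parameter identified jointly with the network.
                C = dde.Variable(1.0)
                model.compile(
                    "adam",
                    lr=1e-3,
                    decay=("inverse time", 1000, 0.5),
                    loss_weights=[1, 10],
                    external_trainable_variables=C,
                )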
"""
if verbose > 0 and config.rank == 0:
print("Compiling model...")
self.opt_name = optimizer
loss_fn = losses_module.get(loss)
self.loss_weights = loss_weights
if external_trainable_variables is None:
self.external_trainable_variables = []
else:
if backend_name == "tensorflow.compat.v1":
print(
"Warning: For the backend tensorflow.compat.v1, "
"`external_trainable_variables` is ignored, and all trainable "
"``tf.Variable`` objects are automatically collected."
)
if not isinstance(external_trainable_variables, list):
external_trainable_variables = [external_trainable_variables]
self.external_trainable_variables = external_trainable_variables
if backend_name == "tensorflow.compat.v1":
self._compile_tensorflow_compat_v1(lr, loss_fn, decay)
elif backend_name == "tensorflow":
self._compile_tensorflow(lr, loss_fn, decay)
elif backend_name == "pytorch":
self._compile_pytorch(lr, loss_fn, decay)
elif backend_name == "jax":
self._compile_jax(lr, loss_fn, decay)
elif backend_name == "paddle":
self._compile_paddle(lr, loss_fn, decay)
# metrics may use model variables such as self.net, and thus are instantiated
# after backend compile.
metrics = metrics or []
self.metrics = [metrics_module.get(m) for m in metrics]
def _compile_tensorflow_compat_v1(self, lr, loss_fn, decay):
"""tensorflow.compat.v1"""
if not self.net.built:
self.net.build()
if self.sess is None:
if config.xla_jit:
cfg = tf.ConfigProto()
cfg.graph_options.optimizer_options.global_jit_level = (
tf.OptimizerOptions.ON_2
)
self.sess = tf.Session(config=cfg)
elif config.hvd is not None:
cfg = tf.ConfigProto()
cfg.gpu_options.visible_device_list = str(config.rank)
self.sess = tf.Session(config=cfg)
else:
self.sess = tf.Session()
self.saver = tf.train.Saver(max_to_keep=None)
def losses(losses_fn):
# Data losses
losses = losses_fn(
self.net.targets, self.net.outputs, loss_fn, self.net.inputs, self
)
if not isinstance(losses, list):
losses = [losses]
# Regularization loss
if self.net.regularizer is not None:
losses.append(
tf.losses.get_regularization_loss() + self.net.regularization_loss
)
losses = tf.convert_to_tensor(losses)
# Weighted losses
if self.loss_weights is not None:
losses *= self.loss_weights
return losses
losses_train = losses(self.data.losses_train)
losses_test = losses(self.data.losses_test)
total_loss = tf.math.reduce_sum(losses_train)
# Tensors
self.outputs = self.net.outputs
self.outputs_losses_train = [self.net.outputs, losses_train]
self.outputs_losses_test = [self.net.outputs, losses_test]
self.train_step = optimizers.get(
total_loss, self.opt_name, learning_rate=lr, decay=decay
)
def _compile_tensorflow(self, lr, loss_fn, decay):
"""tensorflow"""
@tf.function(jit_compile=config.xla_jit)
def outputs(training, inputs):
return self.net(inputs, training=training)
def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
self.net.auxiliary_vars = auxiliary_vars
# Don't call outputs() decorated by @tf.function above, otherwise the
# gradient of outputs wrt inputs will be lost here.
outputs_ = self.net(inputs, training=training)
# Data losses
# if forward-mode AD is used, then a forward call needs to be passed
aux = [self.net] if config.autodiff == "forward" else None
losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
if not isinstance(losses, list):
losses = [losses]
# Regularization loss
if self.net.regularizer is not None:
losses += [tf.math.reduce_sum(self.net.losses)]
losses = tf.convert_to_tensor(losses)
# Weighted losses
if self.loss_weights is not None:
losses *= self.loss_weights
return outputs_, losses
@tf.function(jit_compile=config.xla_jit)
def outputs_losses_train(inputs, targets, auxiliary_vars):
return outputs_losses(
True, inputs, targets, auxiliary_vars, self.data.losses_train
)
@tf.function(jit_compile=config.xla_jit)
def outputs_losses_test(inputs, targets, auxiliary_vars):
return outputs_losses(
False, inputs, targets, auxiliary_vars, self.data.losses_test
)
opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay)
@tf.function(jit_compile=config.xla_jit)
def train_step(inputs, targets, auxiliary_vars):
# inputs and targets are np.ndarray and automatically converted to Tensor.
with tf.GradientTape() as tape:
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = tf.math.reduce_sum(losses)
trainable_variables = (
self.net.trainable_variables + self.external_trainable_variables
)
grads = tape.gradient(total_loss, trainable_variables)
opt.apply_gradients(zip(grads, trainable_variables))
def train_step_tfp(
inputs, targets, auxiliary_vars, previous_optimizer_results=None
):
def build_loss():
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
return tf.math.reduce_sum(losses)
trainable_variables = (
self.net.trainable_variables + self.external_trainable_variables
)
return opt(trainable_variables, build_loss, previous_optimizer_results)
# Callables
self.outputs = outputs
self.outputs_losses_train = outputs_losses_train
self.outputs_losses_test = outputs_losses_test
self.train_step = (
train_step
if not optimizers.is_external_optimizer(self.opt_name)
else train_step_tfp
)
def _compile_pytorch(self, lr, loss_fn, decay):
"""pytorch"""
def outputs(training, inputs):
self.net.train(mode=training)
with torch.no_grad():
if isinstance(inputs, tuple):
inputs = tuple(
map(lambda x: torch.as_tensor(x).requires_grad_(), inputs)
)
else:
inputs = torch.as_tensor(inputs)
inputs.requires_grad_()
# Clear cached Jacobians and Hessians.
grad.clear()
return self.net(inputs)
def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
self.net.auxiliary_vars = None
if auxiliary_vars is not None:
self.net.auxiliary_vars = torch.as_tensor(auxiliary_vars)
self.net.train(mode=training)
if isinstance(inputs, tuple):
inputs = tuple(
map(lambda x: torch.as_tensor(x).requires_grad_(), inputs)
)
else:
inputs = torch.as_tensor(inputs)
inputs.requires_grad_()
outputs_ = self.net(inputs)
# Data losses
if targets is not None:
targets = torch.as_tensor(targets)
# if forward-mode AD is used, then a forward call needs to be passed
aux = [self.net] if config.autodiff == "forward" else None
losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
if not isinstance(losses, list):
losses = [losses]
losses = torch.stack(losses)
# Weighted losses
if self.loss_weights is not None:
losses *= torch.as_tensor(self.loss_weights)
# Clear cached Jacobians and Hessians.
grad.clear()
return outputs_, losses
def outputs_losses_train(inputs, targets, auxiliary_vars):
return outputs_losses(
True, inputs, targets, auxiliary_vars, self.data.losses_train
)
def outputs_losses_test(inputs, targets, auxiliary_vars):
return outputs_losses(
False, inputs, targets, auxiliary_vars, self.data.losses_test
)
weight_decay = 0
if self.net.regularizer is not None:
if self.net.regularizer[0] != "l2":
raise NotImplementedError(
f"{self.net.regularizer[0]} regularization to be implemented for "
"backend pytorch"
)
weight_decay = self.net.regularizer[1]
optimizer_params = self.net.parameters()
if self.external_trainable_variables:
# L-BFGS doesn't support per-parameter options.
if self.opt_name in ["L-BFGS", "L-BFGS-B"]:
optimizer_params = (
list(optimizer_params) + self.external_trainable_variables
)
if weight_decay > 0:
print(
"Warning: L2 regularization will also be applied to external_trainable_variables. "
"Ensure this is intended behavior."
)
else:
optimizer_params = [
{"params": optimizer_params},
{"params": self.external_trainable_variables, "weight_decay": 0},
]
self.opt, self.lr_scheduler = optimizers.get(
optimizer_params,
self.opt_name,
learning_rate=lr,
decay=decay,
weight_decay=weight_decay,
)
def train_step(inputs, targets, auxiliary_vars):
def closure():
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = torch.sum(losses)
self.opt.zero_grad()
total_loss.backward()
return total_loss
self.opt.step(closure)
if self.lr_scheduler is not None:
self.lr_scheduler.step()
def train_step_nncg(inputs, targets, auxiliary_vars):
def closure():
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = torch.sum(losses)
self.opt.zero_grad()
return total_loss
self.opt.step(closure)
if self.lr_scheduler is not None:
self.lr_scheduler.step()
# Callables
self.outputs = outputs
self.outputs_losses_train = outputs_losses_train
self.outputs_losses_test = outputs_losses_test
self.train_step = train_step if self.opt_name != "NNCG" else train_step_nncg
def _compile_jax(self, lr, loss_fn, decay):
"""jax"""
# Initialize the network's parameters
if self.params is None:
key = jax.random.PRNGKey(config.jax_random_seed)
self.net.params = self.net.init(key, self.data.test()[0])
external_trainable_variables_arr = [
var.value for var in self.external_trainable_variables
]
self.params = [self.net.params, external_trainable_variables_arr]
# TODO: learning rate decay
self.opt = optimizers.get(self.opt_name, learning_rate=lr)
self.opt_state = self.opt.init(self.params)
@jax.jit
def outputs(params, training, inputs):
return self.net.apply(params, inputs, training=training)
def outputs_losses(params, training, inputs, targets, losses_fn):
nn_params, ext_params = params
# TODO: Add auxiliary vars
def outputs_fn(inputs):
return self.net.apply(nn_params, inputs, training=training)
outputs_ = self.net.apply(nn_params, inputs, training=training)
# Data losses
# We use aux so that self.data.losses is a pure function.
aux = [outputs_fn, ext_params] if ext_params else [outputs_fn]
losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
# TODO: Add regularization loss
if not isinstance(losses, list):
losses = [losses]
losses = jax.numpy.asarray(losses)
if self.loss_weights is not None:
losses *= jax.numpy.asarray(self.loss_weights)
return outputs_, losses
@jax.jit
def outputs_losses_train(params, inputs, targets):
return outputs_losses(params, True, inputs, targets, self.data.losses_train)
@jax.jit
def outputs_losses_test(params, inputs, targets):
return outputs_losses(params, False, inputs, targets, self.data.losses_test)
@jax.jit
def train_step(params, opt_state, inputs, targets):
def loss_function(params):
return jax.numpy.sum(outputs_losses_train(params, inputs, targets)[1])
grad_fn = jax.grad(loss_function)
grads = grad_fn(params)
updates, new_opt_state = self.opt.update(grads, opt_state)
new_params = optimizers.apply_updates(params, updates)
return new_params, new_opt_state
# Pure functions
self.outputs = outputs
self.outputs_losses_train = outputs_losses_train
self.outputs_losses_test = outputs_losses_test
self.train_step = train_step
def _compile_paddle(self, lr, loss_fn, decay):
"""paddle"""
def outputs(training, inputs):
if training:
self.net.train()
else:
self.net.eval()
with paddle.no_grad():
if isinstance(inputs, tuple):
inputs = tuple(
map(lambda x: paddle.to_tensor(x, stop_gradient=False), inputs)
)
else:
inputs = paddle.to_tensor(inputs, stop_gradient=False)
return self.net(inputs)
def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
self.net.auxiliary_vars = auxiliary_vars
if training:
self.net.train()
else:
self.net.eval()
if isinstance(inputs, tuple):
inputs = tuple(
map(lambda x: paddle.to_tensor(x, stop_gradient=False), inputs)
)
else:
inputs = paddle.to_tensor(inputs, stop_gradient=False)
outputs_ = self.net(inputs)
# Data losses
if targets is not None:
targets = paddle.to_tensor(targets)
losses = losses_fn(targets, outputs_, loss_fn, inputs, self)
if not isinstance(losses, list):
losses = [losses]
# TODO: regularization
losses = paddle.stack(losses, axis=0)
# Weighted losses
if self.loss_weights is not None:
losses *= paddle.to_tensor(self.loss_weights, dtype=losses.dtype)
# Clear cached Jacobians and Hessians.
grad.clear()
return outputs_, losses
def outputs_losses_train(inputs, targets, auxiliary_vars):
return outputs_losses(
True, inputs, targets, auxiliary_vars, self.data.losses_train
)
def outputs_losses_test(inputs, targets, auxiliary_vars):
return outputs_losses(
False, inputs, targets, auxiliary_vars, self.data.losses_test
)
trainable_variables = (
list(self.net.parameters()) + self.external_trainable_variables
)
self.opt = optimizers.get(
trainable_variables, self.opt_name, learning_rate=lr, decay=decay
)
def train_step(inputs, targets, auxiliary_vars):
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = paddle.sum(losses)
total_loss.backward()
self.opt.step()
self.opt.clear_grad()
if self.lr_scheduler is not None:
self.lr_scheduler.step()
def train_step_lbfgs(inputs, targets, auxiliary_vars):
def closure():
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = paddle.sum(losses)
self.opt.clear_grad()
total_loss.backward()
return total_loss
self.opt.step(closure)
# Callables
self.outputs = outputs
self.outputs_losses_train = outputs_losses_train
self.outputs_losses_test = outputs_losses_test
self.train_step = (
train_step
if not optimizers.is_external_optimizer(self.opt_name)
else train_step_lbfgs
)
def _outputs(self, training, inputs):
if backend_name == "tensorflow.compat.v1":
feed_dict = self.net.feed_dict(training, inputs)
return self.sess.run(self.outputs, feed_dict=feed_dict)
if backend_name in ["tensorflow", "pytorch", "paddle"]:
outs = self.outputs(training, inputs)
elif backend_name == "jax":
outs = self.outputs(self.net.params, training, inputs)
return utils.to_numpy(outs)
def _outputs_losses(self, training, inputs, targets, auxiliary_vars):
if training:
outputs_losses = self.outputs_losses_train
else:
outputs_losses = self.outputs_losses_test
if backend_name == "tensorflow.compat.v1":
feed_dict = self.net.feed_dict(training, inputs, targets, auxiliary_vars)
return self.sess.run(outputs_losses, feed_dict=feed_dict)
if backend_name == "tensorflow":
outs = outputs_losses(inputs, targets, auxiliary_vars)
elif backend_name == "pytorch":
self.net.requires_grad_(requires_grad=False)
outs = outputs_losses(inputs, targets, auxiliary_vars)
self.net.requires_grad_()
elif backend_name == "jax":
# TODO: auxiliary_vars
outs = outputs_losses(self.params, inputs, targets)
elif backend_name == "paddle":
outs = outputs_losses(inputs, targets, auxiliary_vars)
return utils.to_numpy(outs[0]), utils.to_numpy(outs[1])
def _train_step(self, inputs, targets, auxiliary_vars):
if backend_name == "tensorflow.compat.v1":
feed_dict = self.net.feed_dict(True, inputs, targets, auxiliary_vars)
self.sess.run(self.train_step, feed_dict=feed_dict)
elif backend_name in ["tensorflow", "paddle"]:
self.train_step(inputs, targets, auxiliary_vars)
elif backend_name == "pytorch":
self.train_step(inputs, targets, auxiliary_vars)
elif backend_name == "jax":
# TODO: auxiliary_vars
self.params, self.opt_state = self.train_step(
self.params, self.opt_state, inputs, targets
)
self.net.params, external_trainable_variables = self.params
for i, var in enumerate(self.external_trainable_variables):
var.value = external_trainable_variables[i]
@utils.timing
def train(
self,
iterations=None,
batch_size=None,
display_every=1000,
disregard_previous_best=False,
callbacks=None,
model_restore_path=None,
model_save_path=None,
epochs=None,
verbose=1,
):
"""Trains the model.
Args:
iterations (Integer): Number of iterations to train the model, i.e., number
of times the network weights are updated.
batch_size: Integer, tuple, or ``None``.
- If you solve PDEs via ``dde.data.PDE`` or ``dde.data.TimePDE``, do not use `batch_size`, and instead use
`dde.callbacks.PDEPointResampler
<https://deepxde.readthedocs.io/en/latest/modules/deepxde.html#deepxde.callbacks.PDEPointResampler>`_,
see an `example <https://github.com/lululxvi/deepxde/blob/master/examples/pinn_forward/diffusion_1d_resample.py>`_.
- For DeepONet in the format of Cartesian product, if `batch_size` is an Integer,
then it is the batch size for the branch input; if you want to also use mini-batch for the trunk net input,
                  set `batch_size` as a tuple, where the first number is the batch size for the branch net input
and the second number is the batch size for the trunk net input.
            display_every (Integer): Interval, in iterations, at which the loss and metrics are printed.
disregard_previous_best: If ``True``, disregard the previous saved best
model.
            callbacks: List of ``dde.callbacks.Callback`` instances to apply during training.
model_restore_path (String): Path where parameters were previously saved.
model_save_path (String): Prefix of filenames created for the checkpoint.
epochs (Integer): Deprecated alias to `iterations`. This will be removed in
a future version.
verbose (Integer): Controls the verbosity of the train process.
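
        Example:
            A minimal sketch. The iteration count, callback settings, and save paths
            are illustrative assumptions::

                checker = dde.callbacks.ModelCheckpoint(
                    "checkpoints/model", save_better_only=True, period=1000
                )
                losshistory, train_state = model.train(
                    iterations=10000,
                    display_every=1000,
                    callbacks=[checker],
                    model_save_path="checkpoints/final",
                )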
"""
if iterations is None and epochs is not None:
print(
"Warning: epochs is deprecated and will be removed in a future version."
" Use iterations instead."
)
iterations = epochs
self.batch_size = batch_size
self.callbacks = CallbackList(callbacks=callbacks)
self.callbacks.set_model(self)
if disregard_previous_best:
self.train_state.disregard_best()
if backend_name == "tensorflow.compat.v1":
if self.train_state.step == 0:
self.sess.run(tf.global_variables_initializer())
if config.hvd is not None:
bcast = config.hvd.broadcast_global_variables(0)
self.sess.run(bcast)
else:
utils.guarantee_initialized_variables(self.sess)
if model_restore_path is not None:
self.restore(model_restore_path, verbose=1)
if verbose > 0 and config.rank == 0:
print("Training model...\n")
self.stop_training = False
self.train_state.set_data_train(*self.data.train_next_batch(self.batch_size))
self.train_state.set_data_test(*self.data.test())
self._test(verbose=verbose)
self.callbacks.on_train_begin()
if optimizers.is_external_optimizer(self.opt_name):
if backend_name == "tensorflow.compat.v1":
self._train_tensorflow_compat_v1_scipy(display_every, verbose=verbose)
elif backend_name == "tensorflow":
self._train_tensorflow_tfp(verbose=verbose)
elif backend_name == "pytorch":
if self.opt_name == "L-BFGS":
self._train_pytorch_lbfgs(verbose=verbose)
elif self.opt_name == "NNCG":
self._train_sgd(iterations, display_every, verbose=verbose)
elif backend_name == "paddle":
self._train_paddle_lbfgs(verbose=verbose)
else:
if iterations is None:
raise ValueError("No iterations for {}.".format(self.opt_name))
self._train_sgd(iterations, display_every, verbose=verbose)
self.callbacks.on_train_end()
if verbose > 0 and config.rank == 0:
print("")
display.training_display.summary(self.train_state)
if model_save_path is not None:
self.save(model_save_path, verbose=1)
return self.losshistory, self.train_state
def _train_sgd(self, iterations, display_every, verbose=1):
for i in range(iterations):
self.callbacks.on_epoch_begin()
self.callbacks.on_batch_begin()
self.train_state.set_data_train(
*self.data.train_next_batch(self.batch_size)
)
self._train_step(
self.train_state.X_train,
self.train_state.y_train,
self.train_state.train_aux_vars,
)
self.train_state.epoch += 1
self.train_state.step += 1
if self.train_state.step % display_every == 0 or i + 1 == iterations:
self._test(verbose=verbose)
self.callbacks.on_batch_end()
self.callbacks.on_epoch_end()
if self.stop_training:
break
def _train_tensorflow_compat_v1_scipy(self, display_every, verbose=1):
def loss_callback(loss_train, loss_test, *args):
self.train_state.epoch += 1
self.train_state.step += 1
if self.train_state.step % display_every == 0:
self.train_state.loss_train = loss_train
self.train_state.loss_test = loss_test
self.train_state.metrics_test = None
self.losshistory.append(
self.train_state.step,
self.train_state.loss_train,
self.train_state.loss_test,
None,
)
if verbose > 0:
display.training_display(self.train_state)
for cb in self.callbacks.callbacks:
if type(cb).__name__ == "VariableValue":
cb.epochs_since_last += 1
if cb.epochs_since_last >= cb.period:
cb.epochs_since_last = 0
print(
cb.model.train_state.epoch,
list_to_str(
[float(arg) for arg in args],
precision=cb.precision,
),
file=cb.file,
)
cb.file.flush()
self.train_state.set_data_train(*self.data.train_next_batch(self.batch_size))
feed_dict = self.net.feed_dict(
True,
self.train_state.X_train,
self.train_state.y_train,
self.train_state.train_aux_vars,
)
fetches = [self.outputs_losses_train[1], self.outputs_losses_test[1]]
if self.external_trainable_variables:
fetches += self.external_trainable_variables
self.train_step.minimize(
self.sess,
feed_dict=feed_dict,
fetches=fetches,
loss_callback=loss_callback,
)
self._test(verbose=verbose)
def _train_tensorflow_tfp(self, verbose=1):
        # There is only one optimization step. When multiple steps were used, with or
        # without previous_optimizer_results, L-BFGS failed to reach a small error,
        # possibly because tfp.optimizer.lbfgs_minimize starts from scratch on each
        # call.
n_iter = 0
while n_iter < optimizers.LBFGS_options["maxiter"]:
self.train_state.set_data_train(
*self.data.train_next_batch(self.batch_size)
)
results = self.train_step(
self.train_state.X_train,
self.train_state.y_train,
self.train_state.train_aux_vars,
)
n_iter += results.num_iterations.numpy()
self.train_state.epoch += results.num_iterations.numpy()
self.train_state.step += results.num_iterations.numpy()
self._test(verbose=verbose)
if results.converged or results.failed:
break
def _train_pytorch_lbfgs(self, verbose=1):
prev_n_iter = 0
while prev_n_iter < optimizers.LBFGS_options["maxiter"]:
self.callbacks.on_epoch_begin()
self.callbacks.on_batch_begin()
self.train_state.set_data_train(
*self.data.train_next_batch(self.batch_size)
)
self._train_step(
self.train_state.X_train,
self.train_state.y_train,
self.train_state.train_aux_vars,
)
n_iter = self.opt.state_dict()["state"][0]["n_iter"]
if prev_n_iter == n_iter - 1:
# Converged
break
self.train_state.epoch += n_iter - prev_n_iter
self.train_state.step += n_iter - prev_n_iter
prev_n_iter = n_iter
self._test(verbose=verbose)
self.callbacks.on_batch_end()
self.callbacks.on_epoch_end()
if self.stop_training:
break
def _train_paddle_lbfgs(self, verbose=1):
prev_n_iter = 0
while prev_n_iter < optimizers.LBFGS_options["maxiter"]:
self.callbacks.on_epoch_begin()
self.callbacks.on_batch_begin()
self.train_state.set_data_train(
*self.data.train_next_batch(self.batch_size)
)
self._train_step(
self.train_state.X_train,
self.train_state.y_train,
self.train_state.train_aux_vars,
)
n_iter = self.opt.state_dict()["state"]["n_iter"]
if prev_n_iter == n_iter - 1:
# Converged
break
self.train_state.epoch += n_iter - prev_n_iter
self.train_state.step += n_iter - prev_n_iter
prev_n_iter = n_iter
self._test(verbose=verbose)
self.callbacks.on_batch_end()
self.callbacks.on_epoch_end()
if self.stop_training:
break
def _test(self, verbose=1):
        # TODO: Currently only the training loss of rank 0 is printed. The correct way
        # is to print the average training loss over all ranks.
(
self.train_state.y_pred_train,
self.train_state.loss_train,
) = self._outputs_losses(
True,
self.train_state.X_train,
self.train_state.y_train,
self.train_state.train_aux_vars,
)
self.train_state.y_pred_test, self.train_state.loss_test = self._outputs_losses(
False,
self.train_state.X_test,
self.train_state.y_test,
self.train_state.test_aux_vars,
)
if isinstance(self.train_state.y_test, (list, tuple)):
self.train_state.metrics_test = [
m(self.train_state.y_test[i], self.train_state.y_pred_test[i])
for m in self.metrics
for i in range(len(self.train_state.y_test))
]
else:
self.train_state.metrics_test = [
m(self.train_state.y_test, self.train_state.y_pred_test)
for m in self.metrics
]
self.train_state.update_best()
self.losshistory.append(
self.train_state.step,
self.train_state.loss_train,
self.train_state.loss_test,
self.train_state.metrics_test,
)
if (
np.isnan(self.train_state.loss_train).any()
or np.isnan(self.train_state.loss_test).any()
):
self.stop_training = True
if verbose > 0 and config.rank == 0:
display.training_display(self.train_state)
def predict(self, x, operator=None, callbacks=None):
"""Generates predictions for the input samples. If `operator` is ``None``,
returns the network output, otherwise returns the output of the `operator`.
Args:
x: The network inputs. A Numpy array or a tuple of Numpy arrays.
            operator: A function that takes arguments (`inputs`, `outputs`) or (`inputs`,
`outputs`, `auxiliary_variables`) and outputs a tensor. `inputs` and
`outputs` are the network input and output tensors, respectively.
`auxiliary_variables` is the output of `auxiliary_var_function(x)`
in `dde.data.PDE`. `operator` is typically chosen as the PDE (used to
define `dde.data.PDE`) to predict the PDE residual.
            callbacks: List of ``dde.callbacks.Callback`` instances to apply during prediction.
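
        Example:
            A hedged sketch of predicting the network output and the PDE residual on
            new points; ``pde`` is assumed to be the residual function used to define
            ``dde.data.PDE``, and ``geom`` the corresponding geometry::

                x = geom.uniform_points(100)
                y = model.predict(x)
                residual = model.predict(x, operator=pde)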
"""
if isinstance(x, tuple):
x = tuple(np.asarray(xi, dtype=config.real(np)) for xi in x)
else:
x = np.asarray(x, dtype=config.real(np))
callbacks = CallbackList(callbacks=callbacks)
callbacks.set_model(self)
callbacks.on_predict_begin()
if operator is None:
y = self._outputs(False, x)
callbacks.on_predict_end()
return y
# operator is not None
if utils.get_num_args(operator) == 3:
aux_vars = self.data.auxiliary_var_fn(x).astype(config.real(np))
if backend_name == "tensorflow.compat.v1":
if utils.get_num_args(operator) == 2:
op = operator(self.net.inputs, self.net.outputs)
feed_dict = self.net.feed_dict(False, x)
elif utils.get_num_args(operator) == 3:
op = operator(
self.net.inputs, self.net.outputs, self.net.auxiliary_vars
)
feed_dict = self.net.feed_dict(False, x, auxiliary_vars=aux_vars)
y = self.sess.run(op, feed_dict=feed_dict)
elif backend_name == "tensorflow":
if utils.get_num_args(operator) == 2:
@tf.function
def op(inputs):
y = self.net(inputs)
if config.autodiff == "forward":
y = (y, self.net)
return operator(inputs, y)
elif utils.get_num_args(operator) == 3:
@tf.function
def op(inputs):
y = self.net(inputs)
return operator(inputs, y, aux_vars)
y = op(x)
y = utils.to_numpy(y)
elif backend_name == "pytorch":
self.net.eval()
if isinstance(x, tuple):
inputs = tuple(map(lambda x: torch.as_tensor(x).requires_grad_(), x))
else:
inputs = torch.as_tensor(x).requires_grad_()
outputs = self.net(inputs)
if utils.get_num_args(operator) == 2:
if config.autodiff == "forward":
outputs = (outputs, self.net)
y = operator(inputs, outputs)
elif utils.get_num_args(operator) == 3:
                # TODO: PyTorch backend implementation of auxiliary variables.
# y = operator(inputs, outputs, torch.as_tensor(aux_vars))
raise NotImplementedError(
"Model.predict() with auxiliary variable hasn't been implemented "
"for backend pytorch."
)
# Clear cached Jacobians and Hessians.
grad.clear()
y = utils.to_numpy(y)
elif backend_name == "jax":
if utils.get_num_args(operator) == 2:
@jax.jit
def op(inputs):
y_fn = lambda _x: self.net.apply(self.net.params, _x)
return operator(inputs, (y_fn(inputs), y_fn))
elif utils.get_num_args(operator) == 3:
                # TODO: JAX backend implementation of auxiliary variables.
raise NotImplementedError(
"Model.predict() with auxiliary variable hasn't been implemented "
"for backend jax."
)
y = op(x)
y = utils.to_numpy(y)
elif backend_name == "paddle":
self.net.eval()
inputs = paddle.to_tensor(x, stop_gradient=False)
outputs = self.net(inputs)
if utils.get_num_args(operator) == 2:
y = operator(inputs, outputs)
elif utils.get_num_args(operator) == 3:
                # TODO: Paddle backend implementation of auxiliary variables.
# y = operator(inputs, outputs, paddle.to_tensor(aux_vars))
raise NotImplementedError(
"Model.predict() with auxiliary variable hasn't been implemented "
"for backend paddle."
)
y = utils.to_numpy(y)
callbacks.on_predict_end()
return y
# def evaluate(self, x, y, callbacks=None):
# """Returns the loss values & metrics values for the model in test mode."""
# raise NotImplementedError(
# "Model.evaluate to be implemented. Alternatively, use Model.predict."
# )
def state_dict(self):
"""Returns a dictionary containing all variables."""
if backend_name == "tensorflow.compat.v1":
destination = OrderedDict()
variables_names = [v.name for v in tf.global_variables()]
values = self.sess.run(variables_names)
for k, v in zip(variables_names, values):
destination[k] = v
elif backend_name == "tensorflow":
# user-provided variables
destination = {
f"external_trainable_variable:{i}": v
for (i, v) in enumerate(self.external_trainable_variables)
}
            # the parameters of the net
destination.update(self.net.get_weight_paths())
elif backend_name in ["pytorch", "paddle"]:
destination = self.net.state_dict()
else:
raise NotImplementedError(
"state_dict hasn't been implemented for this backend."
)
return destination
def save(self, save_path, protocol="backend", verbose=0):
"""Saves all variables to a disk file.
Args:
save_path (string): Prefix of filenames to save the model file.
protocol (string): If `protocol` is "backend", save using the
backend-specific method.
- For "tensorflow.compat.v1", use `tf.train.Save <https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver#attributes>`_.
- For "tensorflow", use `tf.keras.Model.save_weights <https://www.tensorflow.org/api_docs/python/tf/keras/Model#save_weights>`_.
- For "pytorch", use `torch.save <https://pytorch.org/docs/stable/generated/torch.save.html>`_.
- For "paddle", use `paddle.save <https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/save_en.html>`_.
If `protocol` is "pickle", save using the Python pickle module. Only the
protocol "backend" supports ``restore()``.
Returns:
string: Path where model is saved.
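
        Example:
            A sketch of saving a checkpoint with the default "backend" protocol; the
            path prefix is an illustrative assumption, and the actual filename also
            depends on the backend and the current epoch::

                path = model.save("checkpoints/model", verbose=1)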
"""
save_path = f"{save_path}-{self.train_state.epoch}"
if protocol == "pickle":
save_path += ".pkl"
with open(save_path, "wb") as f:
pickle.dump(self.state_dict(), f)
elif protocol == "backend":
if backend_name == "tensorflow.compat.v1":
save_path += ".ckpt"
self.saver.save(self.sess, save_path)
elif backend_name == "tensorflow":
save_path += ".weights.h5"
self.net.save_weights(save_path)
elif backend_name == "pytorch":
save_path += ".pt"
checkpoint = {
"model_state_dict": self.net.state_dict(),
"optimizer_state_dict": self.opt.state_dict(),
}
torch.save(checkpoint, save_path)
elif backend_name == "paddle":
save_path += ".pdparams"
checkpoint = {
"model": self.net.state_dict(),
"opt": self.opt.state_dict(),
}
paddle.save(checkpoint, save_path)
else:
raise NotImplementedError(
"Model.save() hasn't been implemented for this backend."
)
if verbose > 0:
print(
"Epoch {}: saving model to {} ...\n".format(
self.train_state.epoch, save_path
)
)
return save_path
def restore(self, save_path, device=None, verbose=0):
"""Restore all variables from a disk file.
Args:
save_path (string): Path where model was previously saved.
            device (string, optional): Device to load the model on, e.g., "cpu" or
                "cuda:0". By default, the model is loaded on the device it was saved
                from.
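
        Example:
            A sketch of restoring from a checkpoint produced by ``Model.save()``; the
            save path and the ``device`` argument (backend pytorch only) are
            illustrative::

                path = model.save("checkpoints/model")
                model.restore(path, device="cpu", verbose=1)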
"""
# TODO: backend tensorflow
if device is not None and backend_name != "pytorch":
print(
"Warning: device is only supported for backend pytorch. Model will be loaded on the device it was saved from."
)
if verbose > 0:
print("Restoring model from {} ...\n".format(save_path))
if backend_name == "tensorflow.compat.v1":
self.saver.restore(self.sess, save_path)
elif backend_name == "tensorflow":
self.net.load_weights(save_path)
elif backend_name == "pytorch":
if device is not None:
checkpoint = torch.load(save_path, map_location=torch.device(device))
else:
checkpoint = torch.load(save_path)
self.net.load_state_dict(checkpoint["model_state_dict"])
self.opt.load_state_dict(checkpoint["optimizer_state_dict"])
elif backend_name == "paddle":
checkpoint = paddle.load(save_path)
self.net.set_state_dict(checkpoint["model"])
self.opt.set_state_dict(checkpoint["opt"])
else:
raise NotImplementedError(
"Model.restore() hasn't been implemented for this backend."
)
def print_model(self):
"""Prints all trainable variables."""
# TODO: backend tensorflow, pytorch
if backend_name != "tensorflow.compat.v1":
raise NotImplementedError(
"state_dict hasn't been implemented for this backend."
)
variables_names = [v.name for v in tf.trainable_variables()]
values = self.sess.run(variables_names)
for k, v in zip(variables_names, values):
print("Variable: {}, Shape: {}".format(k, v.shape))
print(v)
class TrainState:
def __init__(self):
self.epoch = 0
self.step = 0
# Current data
self.X_train = None
self.y_train = None
self.train_aux_vars = None
self.X_test = None
self.y_test = None
self.test_aux_vars = None
# Results of current step
# Train results
self.loss_train = None
self.y_pred_train = None
# Test results
self.loss_test = None
self.y_pred_test = None
self.y_std_test = None
self.metrics_test = None
# The best results correspond to the min train loss
self.best_step = 0
self.best_loss_train = np.inf
self.best_loss_test = np.inf
self.best_y = None
self.best_ystd = None
self.best_metrics = None
def set_data_train(self, X_train, y_train, train_aux_vars=None):
self.X_train = X_train
self.y_train = y_train
self.train_aux_vars = train_aux_vars
def set_data_test(self, X_test, y_test, test_aux_vars=None):
self.X_test = X_test
self.y_test = y_test
self.test_aux_vars = test_aux_vars
def update_best(self):
if self.best_loss_train > np.sum(self.loss_train):
self.best_step = self.step
self.best_loss_train = np.sum(self.loss_train)
self.best_loss_test = np.sum(self.loss_test)
self.best_y = self.y_pred_test
self.best_ystd = self.y_std_test
self.best_metrics = self.metrics_test
def disregard_best(self):
self.best_loss_train = np.inf
class LossHistory:
def __init__(self):
self.steps = []
self.loss_train = []
self.loss_test = []
self.metrics_test = []
def append(self, step, loss_train, loss_test, metrics_test):
self.steps.append(step)
self.loss_train.append(loss_train)
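        # If the test loss or metrics were not re-evaluated at this step, carry the
        # most recent values forward.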
if loss_test is None:
loss_test = self.loss_test[-1]
if metrics_test is None:
metrics_test = self.metrics_test[-1]
self.loss_test.append(loss_test)
self.metrics_test.append(metrics_test)