"""Source code for deepxde.callbacks."""

import sys
import time

import numpy as np

from . import config
from . import gradients as grad
from . import utils
from .backend import backend_name, jax, paddle, tf, torch


class Callback:
    """Callback base class.

    Attributes:
        model: instance of ``Model``. Reference of the model being trained.
    """

    def __init__(self):
        self.model = None

    def set_model(self, model):
        # Re-attaching the same model is a no-op; ``init`` only runs when the
        # attached model actually changes.
        if model is self.model:
            return
        self.model = model
        self.init()

    def init(self):
        """Init after setting a model."""

    def on_epoch_begin(self):
        """Called at the beginning of every epoch."""

    def on_epoch_end(self):
        """Called at the end of every epoch."""

    def on_batch_begin(self):
        """Called at the beginning of every batch."""

    def on_batch_end(self):
        """Called at the end of every batch."""

    def on_train_begin(self):
        """Called at the beginning of model training."""

    def on_train_end(self):
        """Called at the end of model training."""

    def on_predict_begin(self):
        """Called at the beginning of prediction."""

    def on_predict_end(self):
        """Called at the end of prediction."""
class CallbackList(Callback):
    """Container abstracting a list of callbacks.

    Every hook simply forwards to each contained callback in order.

    Args:
        callbacks: List of ``Callback`` instances.
    """

    def __init__(self, callbacks=None):
        # Fix: call the base initializer for consistency with the other
        # Callback subclasses (it sets self.model = None).
        super().__init__()
        # Copy so later mutation of the caller's list does not affect us.
        self.callbacks = list(callbacks or [])

    def set_model(self, model):
        self.model = model
        for callback in self.callbacks:
            callback.set_model(model)

    def on_epoch_begin(self):
        for callback in self.callbacks:
            callback.on_epoch_begin()

    def on_epoch_end(self):
        for callback in self.callbacks:
            callback.on_epoch_end()

    def on_batch_begin(self):
        for callback in self.callbacks:
            callback.on_batch_begin()

    def on_batch_end(self):
        for callback in self.callbacks:
            callback.on_batch_end()

    def on_train_begin(self):
        for callback in self.callbacks:
            callback.on_train_begin()

    def on_train_end(self):
        for callback in self.callbacks:
            callback.on_train_end()

    def on_predict_begin(self):
        for callback in self.callbacks:
            callback.on_predict_begin()

    def on_predict_end(self):
        for callback in self.callbacks:
            callback.on_predict_end()

    def append(self, callback):
        """Add a callback to the list.

        Raises:
            TypeError: If ``callback`` is not a ``Callback`` instance.
        """
        if not isinstance(callback, Callback):
            # Fix: raise TypeError (a wrong-type error) instead of the generic
            # Exception; existing `except Exception` handlers still catch it.
            raise TypeError(str(callback) + " is an invalid Callback object")
        self.callbacks.append(callback)
class ModelCheckpoint(Callback):
    """Save the model after every epoch.

    Args:
        filepath (string): Prefix of filenames to save the model file.
        verbose: Verbosity mode, 0 or 1.
        save_better_only: If True, only save a better model according to the
            quantity monitored. Model is only checked at validation step
            according to ``display_every`` in ``Model.train``.
        period: Interval (number of epochs) between checkpoints.
        monitor: The loss function that is monitored. Either 'train loss' or
            'test loss'.
    """

    def __init__(
        self,
        filepath,
        verbose=0,
        save_better_only=False,
        period=1,
        monitor="train loss",
    ):
        super().__init__()
        self.filepath = filepath
        self.verbose = verbose
        self.save_better_only = save_better_only
        self.period = period
        self.monitor = monitor
        # "Better" means a strictly smaller monitored loss.
        self.monitor_op = np.less
        self.epochs_since_last_save = 0
        self.best = np.inf

    def on_epoch_end(self):
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0
        if not self.save_better_only:
            self.model.save(self.filepath, verbose=self.verbose)
            return
        current = self.get_monitor_value()
        if not self.monitor_op(current, self.best):
            return
        save_path = self.model.save(self.filepath, verbose=0)
        if self.verbose > 0:
            print(
                "Epoch {}: {} improved from {:.2e} to {:.2e}, saving model to {} ...\n".format(
                    self.model.train_state.epoch,
                    self.monitor,
                    self.best,
                    current,
                    save_path,
                )
            )
        self.best = current

    def get_monitor_value(self):
        """Return the current value of the monitored loss."""
        if self.monitor == "train loss":
            return sum(self.model.train_state.loss_train)
        if self.monitor == "test loss":
            return sum(self.model.train_state.loss_test)
        raise ValueError("The specified monitor function is incorrect.")
class EarlyStopping(Callback):
    """Stop training when a monitored quantity (training or testing loss) has stopped improving.

    Only checked at validation step according to ``display_every`` in
    ``Model.train``.

    Args:
        min_delta: Minimum change in the monitored quantity to qualify as an
            improvement, i.e. an absolute change of less than min_delta, will
            count as no improvement.
        patience: Number of epochs with no improvement after which training
            will be stopped.
        baseline: Baseline value for the monitored quantity to reach. Training
            will stop if the model doesn't show improvement over the baseline.
        monitor: The loss function that is monitored. Either 'loss_train' or
            'loss_test'.
        start_from_epoch: Number of epochs to wait before starting to monitor
            improvement. This allows for a warm-up period in which no
            improvement is expected and thus training will not be stopped.
    """

    def __init__(
        self,
        min_delta=0,
        patience=0,
        baseline=None,
        monitor="loss_train",
        start_from_epoch=0,
    ):
        super().__init__()
        self.baseline = baseline
        self.monitor = monitor
        self.patience = patience
        # Stored negated: "current - self.min_delta" in on_epoch_end then
        # means "current + |min_delta|", i.e. improvement beyond min_delta.
        self.min_delta = -min_delta
        self.wait = 0
        self.stopped_epoch = 0
        self.start_from_epoch = start_from_epoch
        self.monitor_op = np.less

    def on_train_begin(self):
        # Allow instances to be re-used.
        self.wait = 0
        self.stopped_epoch = 0
        if self.baseline is not None:
            self.best = self.baseline
            return
        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_epoch_end(self):
        if self.model.train_state.epoch < self.start_from_epoch:
            return
        current = self.get_monitor_value()
        if self.monitor_op(current - self.min_delta, self.best):
            # Improved by more than min_delta: reset the patience counter.
            self.best = current
            self.wait = 0
            return
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = self.model.train_state.epoch
            self.model.stop_training = True

    def on_train_end(self):
        if self.stopped_epoch > 0:
            print(f"Epoch {self.stopped_epoch}: early stopping")

    def get_monitor_value(self):
        """Return the current value of the monitored loss."""
        if self.monitor == "loss_train":
            return sum(self.model.train_state.loss_train)
        if self.monitor == "loss_test":
            return sum(self.model.train_state.loss_test)
        raise ValueError("The specified monitor function is incorrect.")
class Timer(Callback):
    """Stop training when training time reaches the threshold.

    This Timer starts after the first call of `on_train_begin`.

    Args:
        available_time (float): Total time (in minutes) available for the
            training.
    """

    def __init__(self, available_time):
        super().__init__()
        self.threshold = available_time * 60  # convert to seconds
        self.t_start = None

    def on_train_begin(self):
        # Only the first call starts the clock; re-used instances keep it.
        if self.t_start is None:
            self.t_start = time.time()

    def on_epoch_end(self):
        elapsed = time.time() - self.t_start
        if elapsed <= self.threshold:
            return
        self.model.stop_training = True
        print(
            "\nStop training as time used up. time used: {:.1f} mins, epoch trained: {}".format(
                elapsed / 60, self.model.train_state.epoch
            )
        )
class DropoutUncertainty(Callback):
    """Uncertainty estimation via MC dropout.

    References:
        `Y. Gal, & Z. Ghahramani. Dropout as a Bayesian approximation:
        Representing model uncertainty in deep learning. International
        Conference on Machine Learning, 2016
        <https://arxiv.org/abs/1506.02142>`_.

    Warning:
        This cannot be used together with other techniques that have different
        behaviors during training and testing, such as batch normalization.
    """

    def __init__(self, period=1000):
        super().__init__()
        self.period = period
        self.epochs_since_last = 0

    def on_epoch_end(self):
        self.epochs_since_last += 1
        if self.epochs_since_last < self.period:
            return
        self.epochs_since_last = 0
        # 1000 stochastic forward passes with dropout active (training=True);
        # the per-point std over the passes is the uncertainty estimate.
        samples = [
            self.model._outputs(True, self.model.train_state.X_test)
            for _ in range(1000)
        ]
        self.model.train_state.y_std_test = np.std(samples, axis=0)

    def on_train_end(self):
        self.on_epoch_end()
class VariableValue(Callback):
    """Get the variable values.

    Args:
        var_list: A `TensorFlow Variable
            <https://www.tensorflow.org/api_docs/python/tf/Variable>`_ or a
            list of TensorFlow Variable.
        period (int): Interval (number of epochs) between checking values.
        filename (string): Output the values to the file `filename`. The file
            is kept open to allow instances to be re-used. If ``None``, output
            to the screen.
        precision (int): The precision of variables to display.
    """

    def __init__(self, var_list, period=1, filename=None, precision=2):
        super().__init__()
        # Normalize a single variable to a one-element list.
        self.var_list = var_list if isinstance(var_list, list) else [var_list]
        self.period = period
        self.precision = precision
        # buffering=1 -> line-buffered; handle stays open for re-use.
        self.file = sys.stdout if filename is None else open(filename, "w", buffering=1)
        self.value = None
        self.epochs_since_last = 0

    def on_train_begin(self):
        # Extract plain Python numbers from the backend-specific variables.
        if backend_name == "tensorflow.compat.v1":
            self.value = self.model.sess.run(self.var_list)
        elif backend_name == "tensorflow":
            self.value = [v.numpy() for v in self.var_list]
        elif backend_name in ["pytorch", "paddle"]:
            self.value = [v.detach().item() for v in self.var_list]
        elif backend_name == "jax":
            self.value = [v.value for v in self.var_list]
        print(
            self.model.train_state.epoch,
            utils.list_to_str(self.value, precision=self.precision),
            file=self.file,
        )
        self.file.flush()

    def on_epoch_end(self):
        self.epochs_since_last += 1
        if self.epochs_since_last >= self.period:
            self.epochs_since_last = 0
            self.on_train_begin()

    def on_train_end(self):
        # Emit a final record unless one was just written this period.
        if self.epochs_since_last != 0:
            self.on_train_begin()

    def get_value(self):
        """Return the variable values."""
        return self.value
class OperatorPredictor(Callback):
    """Generates operator values for the input samples.

    Args:
        x: The input data.
        op: The operator with inputs (x, y).
        period (int): Interval (number of epochs) between checking values.
        filename (string): Output the values to the file `filename`. The file
            is kept open to allow instances to be re-used. If ``None``, output
            to the screen.
        precision (int): The precision of variables to display.
    """

    def __init__(self, x, op, period=1, filename=None, precision=2):
        super().__init__()
        self.x = x
        self.op = op
        self.period = period
        self.precision = precision
        # buffering=1 -> line-buffered; the handle is deliberately kept open
        # so a re-used instance keeps writing to the same file.
        self.file = sys.stdout if filename is None else open(filename, "w", buffering=1)
        self.value = None
        self.epochs_since_last = 0

    def init(self):
        # Prepare a backend-specific way to evaluate `op` once the model is
        # attached; for pytorch/paddle the inputs are converted to tensors
        # in place (with gradients enabled, since `op` may differentiate).
        if backend_name == "tensorflow.compat.v1":
            self.tf_op = self.op(self.model.net.inputs, self.model.net.outputs)
        elif backend_name == "tensorflow":

            @tf.function
            def op(inputs):
                y = self.model.net(inputs)
                return self.op(inputs, y)

            self.tf_op = op
        elif backend_name == "pytorch":
            self.x = torch.as_tensor(self.x)
            self.x.requires_grad_()
        elif backend_name == "jax":

            @jax.jit
            def op(inputs, params):
                # `op` receives (outputs, output_fn) so it can take further
                # derivatives of the network via y_fn.
                y_fn = lambda _x: self.model.net.apply(params, _x)
                return self.op(inputs, (y_fn(inputs), y_fn))

            self.jax_op = op
        elif backend_name == "paddle":
            self.x = paddle.to_tensor(self.x, stop_gradient=False)

    def on_train_begin(self):
        # Evaluate the operator and write one "<epoch> <values>" record.
        self.on_predict_end()
        print(
            self.model.train_state.epoch,
            utils.list_to_str(self.value.flatten().tolist(), precision=self.precision),
            file=self.file,
        )
        self.file.flush()

    def on_train_end(self):
        # Emit a final record unless one was just written this period.
        if not self.epochs_since_last == 0:
            self.on_train_begin()

    def on_epoch_end(self):
        self.epochs_since_last += 1
        if self.epochs_since_last >= self.period:
            self.epochs_since_last = 0
            self.on_train_begin()

    def on_predict_end(self):
        # Compute self.value = op(x, y) with the backend-specific machinery
        # built in init().
        if backend_name == "tensorflow.compat.v1":
            self.value = self.model.sess.run(
                self.tf_op, feed_dict=self.model.net.feed_dict(False, self.x)
            )
        elif backend_name == "tensorflow":
            self.value = utils.to_numpy(self.tf_op(self.x))
        elif backend_name == "pytorch":
            self.model.net.eval()
            outputs = self.model.net(self.x)
            self.value = utils.to_numpy(self.op(self.x, outputs))
        elif backend_name == "jax":
            self.value = utils.to_numpy(self.jax_op(self.x, self.model.net.params))
        elif backend_name == "paddle":
            self.model.net.eval()
            outputs = self.model.net(self.x)
            self.value = utils.to_numpy(self.op(self.x, outputs))

    def get_value(self):
        """Return the operator values from the last evaluation."""
        return self.value
class FirstDerivative(OperatorPredictor):
    """Generates the first order derivative of the outputs with respect to the inputs.

    Args:
        x: The input data.
        component_x (int): The input component (column) to differentiate
            with respect to. Default: 0.
        component_y (int): The output component (column) to differentiate.
            Default: 0.
    """

    def __init__(self, x, component_x=0, component_y=0):
        def first_derivative(x, y):
            # d y[component_y] / d x[component_x]
            return grad.jacobian(y, x, i=component_y, j=component_x)

        super().__init__(x, first_derivative)
class MovieDumper(Callback):
    """Dump a movie to show the training progress of the function along a line.

    Args:
        filename (string): Prefix of the output filenames
            (``<filename>_x.txt``, ``<filename>_y.txt``, ``<filename>_y.gif``,
            and, with ``save_spectrum``, ``<filename>_spectrum.txt`` and
            ``<filename>_spectrum.gif``).
        x1: One endpoint of the line segment along which the network output is
            sampled.
        x2: The other endpoint of the line segment.
        num_points (int): Number of sample points on the segment.
        period (int): Interval (number of epochs) between snapshots.
        component (int): The output component to record.
        save_spectrum (bool): If True, dump the spectrum of the Fourier
            transform.
        y_reference: Optional callable evaluated on the sample points to draw
            a reference curve in the movie.
    """

    def __init__(
        self,
        filename,
        x1,
        x2,
        num_points=100,
        period=1,
        component=0,
        save_spectrum=False,
        y_reference=None,
    ):
        super().__init__()
        self.filename = filename
        x1 = np.array(x1)
        x2 = np.array(x2)
        # num_points samples linearly interpolated from x1 to x2 (inclusive).
        self.x = (
            x1 + (x2 - x1) / (num_points - 1) * np.arange(num_points)[:, None]
        ).astype(dtype=config.real(np))
        self.period = period
        self.component = component
        self.save_spectrum = save_spectrum
        self.y_reference = y_reference
        self.y = []  # one snapshot of the sampled output per period
        self.spectrum = []  # |rfft| of each snapshot, if save_spectrum
        self.epochs_since_last_save = 0

    def on_train_begin(self):
        # Take a snapshot of the selected output component along the line.
        self.y.append(self.model._outputs(False, self.x)[:, self.component])
        if self.save_spectrum:
            A = np.fft.rfft(self.y[-1])
            self.spectrum.append(np.abs(A))

    def on_epoch_end(self):
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            self.on_train_begin()

    def on_train_end(self):
        fname_x = self.filename + "_x.txt"
        fname_y = self.filename + "_y.txt"
        fname_movie = self.filename + "_y.gif"
        print(
            "\nSaving the movie of function to {}, {}, {}...".format(
                fname_x, fname_y, fname_movie
            )
        )
        np.savetxt(fname_x, self.x)
        np.savetxt(fname_y, np.array(self.y))
        if self.y_reference is None:
            utils.save_animation(fname_movie, np.ravel(self.x), self.y)
        else:
            y_reference = np.ravel(self.y_reference(self.x))
            utils.save_animation(
                fname_movie, np.ravel(self.x), self.y, y_reference=y_reference
            )

        if self.save_spectrum:
            fname_spec = self.filename + "_spectrum.txt"
            fname_movie = self.filename + "_spectrum.gif"
            print(
                "Saving the movie of spectrum to {}, {}...".format(
                    fname_spec, fname_movie
                )
            )
            np.savetxt(fname_spec, np.array(self.spectrum))
            xdata = np.arange(len(self.spectrum[0]))
            if self.y_reference is None:
                utils.save_animation(fname_movie, xdata, self.spectrum, logy=True)
            else:
                # NOTE(review): relies on the local `y_reference` computed in
                # the branch above; reachable only when self.y_reference is
                # not None, so it is always bound here — but fragile.
                A = np.fft.rfft(y_reference)
                utils.save_animation(
                    fname_movie, xdata, self.spectrum, logy=True, y_reference=np.abs(A)
                )
class PDEPointResampler(Callback):
    """Resample the training points for PDE and/or BC losses every given period.

    Args:
        period: How often to resample the training points (default is 100
            iterations).
        pde_points: If True, resample the training points for PDE losses
            (default is True).
        bc_points: If True, resample the training points for BC losses
            (default is False; only supported by PyTorch and PaddlePaddle
            backend currently).
    """

    def __init__(self, period=100, pde_points=True, bc_points=False):
        super().__init__()
        self.period = period
        self.pde_points = pde_points
        self.bc_points = bc_points
        self.num_bcs_initial = None
        self.epochs_since_last_resample = 0

    def on_train_begin(self):
        # Remember the BC point counts the loss was compiled against.
        self.num_bcs_initial = self.model.data.num_bcs

    def on_epoch_end(self):
        self.epochs_since_last_resample += 1
        if self.epochs_since_last_resample < self.period:
            return
        self.epochs_since_last_resample = 0
        self.model.data.resample_train_points(self.pde_points, self.bc_points)
        if np.array_equal(self.num_bcs_initial, self.model.data.num_bcs):
            return
        # Resampling changed the BC point counts, which invalidates the
        # compiled loss weights/partitioning.
        print("Initial value of self.num_bcs:", self.num_bcs_initial)
        print("self.model.data.num_bcs:", self.model.data.num_bcs)
        raise ValueError(
            "`num_bcs` changed! Please update the loss function by `model.compile`."
        )