Source code for deepxde.optimizers.tensorflow_compat_v1.tfp_optimizer

"""TensorFlow interface for TensorFlow Probability optimizers."""

# This is a standalone example showing how to use TFP optimizers with TensorFlow 1.x.
# The code below is inspired by
# https://gist.github.com/piyueh/712ec7d4540489aad2dcfb80f9a54993.
# Note that the DeepXDE interface to this optimizer is still under development.
# References:
# - https://www.tensorflow.org/probability/api_docs/python/tfp/optimizer/lbfgs_minimize
# - https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/optimizer/lbfgs_test.py
# - https://pychao.com/2019/11/02/optimize-tensorflow-keras-models-with-l-bfgs-from-tensorflow-probability/
# - https://gist.github.com/piyueh/712ec7d4540489aad2dcfb80f9a54993


import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from matplotlib import pyplot

# Switch to TF1 semantics: the code below relies on the graph/session API
# (tf.placeholder, tf.gradients, tf.trainable_variables, session.run).
tf.disable_v2_behavior()


class LossAndFlatGradient:
    """A helper class to create a function required by tfp.optimizer.lbfgs_minimize.

    The trainable variables of the default graph are viewed as one flat 1D
    vector. Index tensors for ``tf.dynamic_stitch`` (variables -> flat vector)
    and ``tf.dynamic_partition`` (flat vector -> variables) are built once in
    the constructor from the variables that exist at that time.

    Args:
        build_loss: A function to build the loss function expression.
    """

    def __init__(self, build_loss):
        self.build_loss = build_loss

        # Shapes of all trainable parameters
        self.shapes = [v.get_shape().as_list() for v in tf.trainable_variables()]
        self.n_tensors = len(self.shapes)

        # Information for tf.dynamic_stitch and tf.dynamic_partition later
        count = 0
        self.indices = []  # stitch indices
        self.partitions = []  # partition indices
        for i, shape in enumerate(self.shapes):
            # np.prod yields a NumPy scalar, and the *float* 1.0 for a scalar
            # variable (shape []); ``[i] * 1.0`` would raise TypeError, so
            # convert to a plain Python int first.
            n = int(np.prod(shape))
            self.indices.append(
                tf.reshape(tf.range(count, count + n, dtype=tf.int32), shape)
            )
            self.partitions.extend([i] * n)
            count += n
        # Pin the dtype: an empty list would otherwise be inferred as float32,
        # which tf.dynamic_partition does not accept as partition indices.
        self.partitions = tf.constant(self.partitions, dtype=tf.int32)

    def __call__(self, weights_1d):
        """A function that can be used by tfp.optimizer.lbfgs_minimize.

        Args:
            weights_1d: a 1D tf.Tensor.

        Returns:
            A scalar loss and the gradients w.r.t. the `weights_1d`.
        """
        # Update the weights
        update_ops = self.set_flat_weights(weights_1d)
        with tf.control_dependencies([update_ops]):
            # Calculate the loss only after the variables hold `weights_1d`
            loss = self.build_loss()
        # Calculate gradients and convert to 1D tf.Tensor
        variables = tf.trainable_variables()
        grads = tf.gradients(loss, variables)
        # tf.gradients returns None for variables the loss does not depend on;
        # tf.dynamic_stitch cannot handle None, so use zeros for those entries.
        grads = [
            tf.zeros_like(variable) if grad is None else grad
            for grad, variable in zip(grads, variables)
        ]
        grads = tf.dynamic_stitch(self.indices, grads)
        return loss, grads

    def set_flat_weights(self, weights_1d):
        """Sets the weights with a 1D tf.Tensor.

        Args:
            weights_1d: a 1D tf.Tensor representing the trainable variables.

        Returns:
            A ``tf.Operation``.
        """
        # Split the flat vector into one 1D piece per trainable variable
        weights = tf.dynamic_partition(weights_1d, self.partitions, self.n_tensors)
        # Reshape every piece back to its variable's shape and assign it
        update_ops = [
            variable.assign(tf.reshape(param, shape))
            for variable, shape, param in zip(
                tf.trainable_variables(), self.shapes, weights
            )
        ]
        return tf.group(*update_ops)
def plot_helper(inputs, outputs, title, fname):
    """Draw a filled contour plot of scattered 2D data and save it to a file.

    A new figure is created and left open, so a later ``pyplot.show()`` can
    still display it.

    Args:
        inputs: Array whose columns 0 and 1 hold the x and y coordinates.
        outputs: Array of values at those points; it is flattened before plotting.
        title: Title of the figure.
        fname: Destination passed to ``pyplot.savefig``.
    """
    pyplot.figure()

    x_coords, y_coords = inputs[:, 0], inputs[:, 1]
    values = outputs.flatten()
    pyplot.tricontourf(x_coords, y_coords, values, 100)

    # Decorations
    pyplot.title(title)
    pyplot.xlabel("x")
    pyplot.ylabel("y")
    pyplot.colorbar()

    pyplot.savefig(fname)
if __name__ == "__main__":
    # Demo: fit z = x^2 + y^2 on an 11 x 11 grid with a small dense network
    # trained by the L-BFGS solver of TensorFlow Probability.
    sess = tf.keras.backend.get_session()

    # use float64 by default
    # tf.keras.backend.set_floatx("float64")

    # prepare training data
    x_1d = np.linspace(-1.0, 1.0, 11, dtype=np.float32)
    x1, x2 = np.meshgrid(x_1d, x_1d)
    inps = np.stack((x1.flatten(), x2.flatten()), 1)
    outs = np.reshape(inps[:, 0] ** 2 + inps[:, 1] ** 2, (x_1d.size ** 2, 1))

    train_x = tf.placeholder(tf.float32, [None, 2])
    train_y = tf.placeholder(tf.float32, [None, 1])

    # prepare prediction model, loss function, and the function passed to L-BFGS solver
    loss_fun = tf.keras.losses.MeanSquaredError()
    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=[2]),
            tf.keras.layers.Dense(64, "tanh"),
            tf.keras.layers.Dense(64, "tanh"),
            tf.keras.layers.Dense(1, None),
        ]
    )
    # Keras losses are called as (y_true, y_pred)
    func = LossAndFlatGradient(lambda: loss_fun(train_y, model(train_x)))

    # Alternative: a plain linear model instead of the Keras network
    # W = tf.Variable(tf.random_normal([2, 1]))
    # b = tf.Variable(tf.zeros(1))
    # y = tf.matmul(train_x, W) + b
    # loss = loss_fun(y, train_y)
    # def build_loss():
    #     return loss + 0
    # func = LossAndFlatGradient(build_loss)

    # convert initial model parameters to a 1D tf.Tensor
    params = tf.dynamic_stitch(func.indices, tf.trainable_variables())

    # train the model with L-BFGS solver
    lbfgs_op = tfp.optimizer.lbfgs_minimize(
        value_and_gradients_function=func, initial_position=params, max_iterations=100
    )

    # Build the "write the optimized parameters back" op once and feed it
    # through a placeholder, so the loop below does not add new ops to the
    # graph on every iteration.
    position = tf.placeholder(params.dtype, [None])
    write_back_op = func.set_flat_weights(position)

    sess.run(tf.global_variables_initializer())
    for i in range(5):
        results = sess.run(lbfgs_op, feed_dict={train_x: inps, train_y: outs})
        # after training, the final optimized parameters are still in results.position
        # so we have to manually put them back to the model
        sess.run(write_back_op, feed_dict={position: results.position})
        print(
            i,
            "converged:",
            results.converged,
            "failed:",
            results.failed,
            "num_iterations:",
            results.num_iterations,
            "loss:",
            results.objective_value,
        )
        if results.converged or results.failed:
            break

    # do some prediction
    pred_outs = sess.run(model(inps))
    # pred_outs = sess.run(y, feed_dict={train_x: inps})
    err = np.abs(pred_outs - outs)
    # Normalize by the number of sample points (11 x 11 = 121), not by 11
    print("\nL2-error norm: {}".format(np.linalg.norm(err) / np.sqrt(err.size)))

    # plot figures
    plot_helper(inps, outs, "Exact solution", "ext_soln.png")
    plot_helper(inps, pred_outs, "Predicted solution", "pred_soln.png")
    plot_helper(inps, err, "Absolute error", "abs_err.png")
    pyplot.show()