"""TensorFlow interface for TensorFlow Probability optimizers."""
# This is a standalone code to show how to use TFP in TensorFlow 1.x. Code below is
# inspired by https://gist.github.com/piyueh/712ec7d4540489aad2dcfb80f9a54993.
# But the DeepXDE interface is still under development...
# References:
# - https://www.tensorflow.org/probability/api_docs/python/tfp/optimizer/lbfgs_minimize
# - https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/optimizer/lbfgs_test.py
# - https://pychao.com/2019/11/02/optimize-tensorflow-keras-models-with-l-bfgs-from-tensorflow-probability/
# - https://gist.github.com/piyueh/712ec7d4540489aad2dcfb80f9a54993
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from matplotlib import pyplot
tf.disable_v2_behavior()
[docs]
class LossAndFlatGradient:
"""A helper class to create a function required by tfp.optimizer.lbfgs_minimize.
Args:
build_loss: A function to build the loss function expression.
"""
def __init__(self, build_loss):
self.build_loss = build_loss
# Shapes of all trainable parameters
self.shapes = [v.get_shape().as_list() for v in tf.trainable_variables()]
self.n_tensors = len(self.shapes)
# Information for tf.dynamic_stitch and tf.dynamic_partition later
count = 0
self.indices = [] # stitch indices
self.partitions = [] # partition indices
for i, shape in enumerate(self.shapes):
n = np.prod(shape)
self.indices.append(
tf.reshape(tf.range(count, count + n, dtype=tf.int32), shape)
)
self.partitions.extend([i] * n)
count += n
self.partitions = tf.constant(self.partitions)
def __call__(self, weights_1d):
"""A function that can be used by tfp.optimizer.lbfgs_minimize.
Args:
weights_1d: a 1D tf.Tensor.
Returns:
A scalar loss and the gradients w.r.t. the `weights_1d`.
"""
# Update the weights
update_ops = self.set_flat_weights(weights_1d)
with tf.control_dependencies([update_ops]):
# Calculate the loss
loss = self.build_loss()
# Calculate gradients and convert to 1D tf.Tensor
grads = tf.gradients(loss, tf.trainable_variables())
grads = tf.dynamic_stitch(self.indices, grads)
return loss, grads
[docs]
def set_flat_weights(self, weights_1d):
"""Sets the weights with a 1D tf.Tensor.
Args:
weights_1d: a 1D tf.Tensor representing the trainable variables.
Returns:
A ``tf.Operation``.
"""
weights = tf.dynamic_partition(weights_1d, self.partitions, self.n_tensors)
variables = tf.trainable_variables()
update_ops = []
for i, (shape, param) in enumerate(zip(self.shapes, weights)):
update_ops.append(variables[i].assign(tf.reshape(param, shape)))
return tf.group(*update_ops)
[docs]
def plot_helper(inputs, outputs, title, fname):
"""Plot helper"""
pyplot.figure()
pyplot.tricontourf(inputs[:, 0], inputs[:, 1], outputs.flatten(), 100)
pyplot.xlabel("x")
pyplot.ylabel("y")
pyplot.title(title)
pyplot.colorbar()
pyplot.savefig(fname)
if __name__ == "__main__":
sess = tf.keras.backend.get_session()
# use float64 by default
# tf.keras.backend.set_floatx("float64")
# prepare training data
x_1d = np.linspace(-1.0, 1.0, 11, dtype=np.float32)
x1, x2 = np.meshgrid(x_1d, x_1d)
inps = np.stack((x1.flatten(), x2.flatten()), 1)
outs = np.reshape(inps[:, 0] ** 2 + inps[:, 1] ** 2, (x_1d.size ** 2, 1))
train_x = tf.placeholder(tf.float32, [None, 2])
train_y = tf.placeholder(tf.float32, [None, 1])
# prepare prediction model, loss function, and the function passed to L-BFGS solver
loss_fun = tf.keras.losses.MeanSquaredError()
model = tf.keras.Sequential(
[
tf.keras.Input(shape=[2]),
tf.keras.layers.Dense(64, "tanh"),
tf.keras.layers.Dense(64, "tanh"),
tf.keras.layers.Dense(1, None),
]
)
func = LossAndFlatGradient(lambda: loss_fun(model(train_x), train_y))
# W = tf.Variable(tf.random_normal([2, 1]))
# b = tf.Variable(tf.zeros(1))
# y = tf.matmul(train_x, W) + b
# loss = loss_fun(y, train_y)
# def build_loss():
# return loss + 0
# func = LossAndFlatGradient(build_loss)
# convert initial model parameters to a 1D tf.Tensor
params = tf.dynamic_stitch(func.indices, tf.trainable_variables())
# train the model with L-BFGS solver
lbfgs_op = tfp.optimizer.lbfgs_minimize(
value_and_gradients_function=func, initial_position=params, max_iterations=100
)
sess.run(tf.global_variables_initializer())
for i in range(5):
results = sess.run(lbfgs_op, feed_dict={train_x: inps, train_y: outs})
# after training, the final optimized parameters are still in results.position
# so we have to manually put them back to the model
sess.run(func.set_flat_weights(results.position))
print(
i,
"converged:",
results.converged,
"failed:",
results.failed,
"num_iterations:",
results.num_iterations,
"loss:",
results.objective_value,
)
if results.converged or results.failed:
break
# do some prediction
pred_outs = sess.run(model(inps))
# pred_outs = sess.run(y, feed_dict={train_x: inps})
err = np.abs(pred_outs - outs)
print("\nL2-error norm: {}".format(np.linalg.norm(err) / np.sqrt(11)))
# plot figures
plot_helper(inps, outs, "Exact solution", "ext_soln.png")
plot_helper(inps, pred_outs, "Predicted solution", "pred_soln.png")
plot_helper(inps, err, "Absolute error", "abs_err.png")
pyplot.show()