Burgers equation
Problem setup
We will solve a Burgers equation:
\[\frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2}, \qquad x \in [-1, 1], \quad t \in [0, 0.99], \qquad \nu = \frac{0.01}{\pi},\]
with the Dirichlet boundary conditions and initial condition
\[u(-1, t) = u(1, t) = 0, \qquad u(x, 0) = -\sin(\pi x).\]
The reference solution is stored in the dataset file Burgers.npz, which the complete code below loads via gen_testdata.
Implementation
This description goes through the implementation of a solver for the Burgers equation described above, step by step.
First, the DeepXDE and NumPy (np) modules are imported:
import deepxde as dde
import numpy as np
We begin by defining a computational geometry and a time domain. We use the built-in classes Interval and TimeDomain, and we combine both domains using GeometryXTime as follows:
geom = dde.geometry.Interval(-1, 1)
timedomain = dde.geometry.TimeDomain(0, 0.99)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
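As an optional sanity check (not part of the original demo), we can sample a few points from the combined space-time geometry, assuming the random_points method that DeepXDE geometries provide:
# Each row of the sample is an (x, t) pair inside the space-time domain.
print(geomtime.random_points(5))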
Next, we express the PDE residual of the Burgers equation:
def pde(x, y):
    dy_x = dde.grad.jacobian(y, x, i=0, j=0)   # du/dx
    dy_t = dde.grad.jacobian(y, x, i=0, j=1)   # du/dt
    dy_xx = dde.grad.hessian(y, x, i=0, j=0)   # d2u/dx2
    return dy_t + y * dy_x - 0.01 / np.pi * dy_xx
The first argument to pde is a 2-dimensional vector, where the first component (x[:, 0]) is the \(x\)-coordinate and the second component (x[:, 1]) is the \(t\)-coordinate. The second argument is the network output, i.e., the solution \(u(x, t)\), but here we use y as the name of the variable. In dde.grad.jacobian(y, x, i=0, j=0), i selects the output component and j the input component, so dy_x is \(\partial u / \partial x\), dy_t is \(\partial u / \partial t\), and dy_xx is \(\partial^2 u / \partial x^2\).
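To make this column convention concrete, here is a small, purely illustrative NumPy snippet (the point values are made up):
import numpy as np

# A hypothetical batch of two collocation points; each row is (x, t).
X = np.array([[-0.5, 0.1],
              [0.25, 0.9]])
x_coord = X[:, 0:1]  # spatial coordinates, shape (2, 1)
t_coord = X[:, 1:2]  # time coordinates, shape (2, 1)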
Next, we consider the boundary/initial conditions. on_boundary is chosen here so that the whole boundary of the computational domain is used for the boundary condition. We pass the geomtime space-time geometry created above and on_boundary as arguments to DeepXDE's DirichletBC function. We also define an IC, which is the initial condition for the Burgers equation; we use the computational domain, the initial-condition function, and on_initial to specify the IC.
bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary)
ic = dde.icbc.IC(geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial)
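The two lambdas simply return the flags that DeepXDE passes in. Written as named functions (an equivalent sketch with hypothetical names, not part of the demo), they read:
def boundary(x, on_boundary):
    # Keep every point DeepXDE flags as lying on the spatial boundary (x = -1 or 1).
    return on_boundary

def initial(x, on_initial):
    # Keep every point DeepXDE flags as lying at the initial time t = 0.
    return on_initial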
Now, we have specified the geometry, PDE residual, and boundary/initial condition. We then define the TimePDE
problem as
data = dde.data.TimePDE(
    geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160
)
The number 2540 is the number of training residual points sampled inside the domain, and the number 80 is the number of training points sampled on the boundary. We also include 160 points for the initial condition, for a total of 2540 + 80 + 160 = 2780 training points.
Next, we choose the network. Here, we use a fully connected neural network of depth 4 (i.e., 3 hidden layers) and width 20:
net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal")
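For reference, the layer-size list in this call expands as follows:
# [2] + [20] * 3 + [1] == [2, 20, 20, 20, 1]
# i.e., 2 inputs (x, t), three hidden layers of width 20, and 1 output u(x, t)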
Now, we have the PDE problem and the network. We build a Model
and choose the optimizer and learning rate:
model = dde.Model(data, net)
model.compile("adam", lr=1e-3)
We then train the model for 15000 iterations:
losshistory, train_state = model.train(iterations=15000)
After we train the network using Adam, we continue to train the network using L-BFGS to achieve a smaller loss:
model.compile("L-BFGS-B")
losshistory, train_state = model.train()
However, L-BFGS can stall out early in optimization if it is unable to find a step size satisfying the strong Wolfe conditions. In such cases, we can use the NNCG optimizer (compatible with PyTorch only) to continue reducing the loss:
dde.optimizers.set_NNCG_options(rank=50, mu=1e-1)
model.compile("NNCG")
losshistory, train_state = model.train(iterations=1000, display_every=100)
By default, NNCG does not run in this demo. You will have to uncomment the NNCG code block in the demo to have it run after Adam and L-BFGS. Note that it can take some hyperparameter tuning to get the best performance from the NNCG optimizer.
Complete code
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
import deepxde as dde
import numpy as np
def gen_testdata():
    data = np.load("../dataset/Burgers.npz")
    t, x, exact = data["t"], data["x"], data["usol"].T
    xx, tt = np.meshgrid(x, t)
    X = np.vstack((np.ravel(xx), np.ravel(tt))).T
    y = exact.flatten()[:, None]
    return X, y


def pde(x, y):
    dy_x = dde.grad.jacobian(y, x, i=0, j=0)
    dy_t = dde.grad.jacobian(y, x, i=0, j=1)
    dy_xx = dde.grad.hessian(y, x, i=0, j=0)
    return dy_t + y * dy_x - 0.01 / np.pi * dy_xx
geom = dde.geometry.Interval(-1, 1)
timedomain = dde.geometry.TimeDomain(0, 0.99)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary)
ic = dde.icbc.IC(
    geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial
)
data = dde.data.TimePDE(
    geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160
)
net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal")
model = dde.Model(data, net)
model.compile("adam", lr=1e-3)
model.train(iterations=15000)
model.compile("L-BFGS")
losshistory, train_state = model.train()
# """Backend supported: pytorch"""
# # Run NNCG after Adam and L-BFGS
# dde.optimizers.set_NNCG_options(rank=50, mu=1e-1)
# model.compile("NNCG")
# losshistory, train_state = model.train(iterations=1000, display_every=100)
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
X, y_true = gen_testdata()
y_pred = model.predict(X)
f = model.predict(X, operator=pde)
print("Mean residual:", np.mean(np.absolute(f)))
print("L2 relative error:", dde.metrics.l2_relative_error(y_true, y_pred))
np.savetxt("test.dat", np.hstack((X, y_true, y_pred)))
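To visualize the prediction, we can reshape y_pred back onto the meshgrid built in gen_testdata; this is an optional matplotlib sketch, not part of the original demo:
import matplotlib.pyplot as plt

data = np.load("../dataset/Burgers.npz")
t, x = data["t"], data["x"]
xx, tt = np.meshgrid(x, t)
plt.pcolormesh(xx, tt, y_pred.reshape(xx.shape), shading="auto", cmap="rainbow")
plt.xlabel("x")
plt.ylabel("t")
plt.colorbar(label="u(x, t)")
plt.show()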