import abc
from typing import Literal
import numpy as np
class Geometry(abc.ABC):
    """Abstract base class for geometries.

    Subclasses must implement ``inside``, ``on_boundary``, ``random_points`` and
    ``random_boundary_points``. The remaining hooks raise ``NotImplementedError``
    unless a subclass overrides them. The ``|``, ``-`` and ``&`` operators build
    CSG combinations through the sibling ``csg`` module.
    """

    def __init__(self, dim, bbox, diam):
        """Store the basic geometric metadata.

        Args:
            dim: Spatial dimension of the geometry.
            bbox: Two-element sequence whose entries are subtracted as
                ``bbox[1] - bbox[0]``; presumably the (lower, upper) corners of
                the bounding box. Entries may be numpy arrays or array-like
                sequences. It is stored unchanged as ``self.bbox``.
            diam: Candidate diameter. It is clipped to the Euclidean length of
                the bounding-box diagonal.
        """
        self.dim = dim
        self.bbox = bbox
        # Convert the corners to arrays so that list/tuple corners subtract
        # element-wise instead of raising TypeError (no-op for numpy arrays).
        diagonal = np.asarray(bbox[1]) - np.asarray(bbox[0])
        self.diam = min(diam, np.linalg.norm(diagonal))
        # Concrete class name, used in the NotImplementedError messages below.
        self.idstr = type(self).__name__

    @abc.abstractmethod
    def inside(self, x):
        """Check if x is inside the geometry (including the boundary)."""

    @abc.abstractmethod
    def on_boundary(self, x):
        """Check if x is on the geometry boundary."""

    def distance2boundary(self, x, dirn):
        """Subclass hook; the base class raises ``NotImplementedError``.

        NOTE(review): judging by the name, this presumably returns the distance
        from ``x`` to the boundary along direction ``dirn`` — confirm against
        the subclasses that implement it.
        """
        raise NotImplementedError(
            "{}.distance2boundary to be implemented".format(self.idstr)
        )

    def mindist2boundary(self, x):
        """Subclass hook; the base class raises ``NotImplementedError``.

        NOTE(review): judging by the name, this presumably returns the minimum
        distance from ``x`` to the boundary — confirm against implementations.
        """
        raise NotImplementedError(
            "{}.mindist2boundary to be implemented".format(self.idstr)
        )

    def boundary_constraint_factor(
        self, x, smoothness: Literal["C0", "C0+", "Cinf"] = "C0+"
    ):
        """Compute the hard constraint factor at x for the boundary.

        This function is used for the hard-constraint methods in Physics-Informed
        Neural Networks (PINNs). The hard constraint factor satisfies the
        following properties:

        - The function is zero on the boundary and positive elsewhere.
        - The function is at least continuous.

        In the ansatz ``boundary_constraint_factor(x) * NN(x) + boundary_condition(x)``,
        when ``x`` is on the boundary, ``boundary_constraint_factor(x)`` is zero,
        so the ansatz reduces to the boundary condition, which in turn makes the
        boundary condition a "hard constraint".

        Args:
            x: A 2D array of shape (n, dim), where ``n`` is the number of points
                and ``dim`` is the dimension of the geometry. Note that ``x``
                should be a tensor type of the backend (e.g., ``tf.Tensor`` or
                ``torch.Tensor``), not a numpy array.
            smoothness (string, optional): Specifies the smoothness of the
                distance function: "C0", "C0+" or "Cinf". "C0" is the least
                smooth, "Cinf" is the most smooth. Default is "C0+".

                - C0
                  The distance function is continuous but may not be
                  differentiable. However, the set of non-differentiable points
                  should have measure zero, so the probability of a collocation
                  point falling in this set is zero.
                - C0+
                  The distance function is continuous and differentiable almost
                  everywhere. Non-differentiable points can only appear on
                  boundaries. If the points in ``x`` are all inside or all
                  outside the geometry, the distance function is smooth.
                - Cinf
                  The distance function is continuous and differentiable to any
                  order at every point. This option may result in a polynomial
                  of HIGH order.

        Returns:
            A tensor of a type determined by the backend, with shape (n, 1).
            Each element is the computed distance value for the respective
            point in ``x``.
        """
        raise NotImplementedError(
            "{}.boundary_constraint_factor to be implemented".format(self.idstr)
        )

    def boundary_normal(self, x):
        """Compute the unit normal at x for Neumann or Robin boundary conditions."""
        raise NotImplementedError(
            "{}.boundary_normal to be implemented".format(self.idstr)
        )

    @abc.abstractmethod
    def random_points(self, n, random="pseudo"):
        """Compute the random point locations in the geometry."""

    @abc.abstractmethod
    def random_boundary_points(self, n, random="pseudo"):
        """Compute the random point locations on the boundary."""

    def periodic_point(self, x, component):
        """Compute the periodic image of x for periodic boundary condition."""
        raise NotImplementedError(
            "{}.periodic_point to be implemented".format(self.idstr)
        )

    def background_points(self, x, dirn, dist2npt, shift):
        """Subclass hook; the base class raises ``NotImplementedError``.

        NOTE(review): the meaning of ``dirn``, ``dist2npt`` and ``shift`` is not
        visible here — confirm against the subclasses that implement it.
        """
        raise NotImplementedError(
            "{}.background_points to be implemented".format(self.idstr)
        )

    def union(self, other):
        """CSG Union."""
        # Imported lazily, presumably to avoid a circular import with ``csg``.
        from . import csg

        return csg.CSGUnion(self, other)

    # ``a | b`` is the operator form of ``a.union(b)``.
    __or__ = union

    def difference(self, other):
        """CSG Difference."""
        from . import csg

        return csg.CSGDifference(self, other)

    # ``a - b`` is the operator form of ``a.difference(b)``.
    __sub__ = difference

    def intersection(self, other):
        """CSG Intersection."""
        from . import csg

        return csg.CSGIntersection(self, other)

    # ``a & b`` is the operator form of ``a.intersection(b)``.
    __and__ = intersection