"""PyTorch backend implementation."""
from packaging.version import Version
import torch
if Version(torch.__version__) < Version("2.0.0"):
    raise RuntimeError("DeepXDE requires PyTorch>=2.0.0.")

# The usual recipe for device-agnostic (CPU or GPU) code is to pick a
# torch.device once and move every tensor onto it explicitly:
# https://pytorch.org/docs/stable/notes/cuda.html
# >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# >>> tensor.to(device=device)
# Doing that for every tensor is a lot of bookkeeping, so instead -- in the same
# spirit as TensorFlow -- the GPU becomes the default device whenever one exists.
if torch.cuda.is_available():
    if Version(torch.__version__) < Version("2.1.0"):
        # Releases before 2.1 keep the legacy default-tensor-type mechanism.
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        torch.set_default_device("cuda")

lib = torch
[docs]
def data_type_dict():
    """Return a mapping from dtype names (e.g. ``"float32"``) to ``torch.dtype`` objects."""
    dtype_names = (
        "float16",
        "float32",
        "float64",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "bool",
    )
    # Every name above is also the attribute name of the matching torch dtype.
    return {name: getattr(torch, name) for name in dtype_names}
[docs]
def is_gpu_available():
    """Report whether PyTorch can see a CUDA device."""
    cuda_ok = torch.cuda.is_available()
    return cuda_ok
[docs]
def is_tensor(obj):
    """Return True if ``obj`` is a ``torch.Tensor``."""
    # torch.is_tensor is documented as a plain isinstance check.
    return isinstance(obj, torch.Tensor)
[docs]
def shape(input_tensor):
    """Return the dimensions of ``input_tensor`` as a plain Python list."""
    return [*input_tensor.shape]
[docs]
def size(tensor):
    """Return the total number of elements in ``tensor``."""
    element_count = torch.numel(tensor)
    return element_count
[docs]
def ndim(input_tensor):
    """Return the number of dimensions (rank) of ``input_tensor``."""
    rank = input_tensor.dim()
    return rank
[docs]
def transpose(tensor, axes=None):
    """Permute the dimensions of ``tensor``.

    Args:
        tensor: Input tensor.
        axes: Sequence giving the new order of the dimensions. When omitted,
            the dimensions are reversed (numpy-style full transpose).

    Returns:
        The permuted tensor.
    """
    if axes is not None:
        return torch.permute(tensor, axes)
    reversed_axes = tuple(reversed(range(tensor.dim())))
    return torch.permute(tensor, reversed_axes)
[docs]
def reshape(tensor, shape):
    """Return ``tensor`` viewed or copied into the requested ``shape``."""
    reshaped = torch.reshape(tensor, shape)
    return reshaped
[docs]
def Variable(initial_value, dtype=None):
    """Create a trainable leaf tensor initialised from ``initial_value``.

    Args:
        initial_value: Data passed to ``torch.tensor``.
        dtype: Optional ``torch.dtype``; ``None`` lets torch pick its default.

    Returns:
        A new tensor with ``requires_grad=True``.
    """
    return torch.tensor(initial_value, requires_grad=True, dtype=dtype)
[docs]
def as_tensor(data, dtype=None):
    """Convert ``data`` to a tensor, avoiding copies when possible.

    An existing tensor is returned unchanged unless a different ``dtype`` is
    requested, in which case a cast copy is returned. Anything else goes
    through ``torch.as_tensor``.
    """
    if not isinstance(data, torch.Tensor):
        return torch.as_tensor(data, dtype=dtype)
    needs_cast = dtype is not None and data.dtype != dtype
    return data.type(dtype=dtype) if needs_cast else data
[docs]
def sparse_tensor(indices, values, shape):
    """Build a sparse COO tensor that tracks gradients.

    Args:
        indices: Coordinates of the non-zero entries, one per entry, each with
            ``len(shape)`` integers (e.g. ``[(row, col), ...]``). May be empty.
        values: The non-zero values, aligned with ``indices``.
        shape: Dense shape of the resulting tensor.

    Returns:
        A ``torch.sparse_coo_tensor`` created with ``requires_grad=True``.
    """
    # torch wants indices laid out as (ndim, nnz), the transpose of the
    # per-entry coordinate list. Building the (nnz, ndim) tensor first and
    # reshaping keeps an empty ``indices`` as a valid (ndim, 0) tensor, whereas
    # transposing via zip(*indices) would collapse it to [] and torch would
    # reject the resulting 1-D index tensor.
    coords = torch.as_tensor(indices, dtype=torch.long).reshape(-1, len(shape))
    return torch.sparse_coo_tensor(coords.t(), values, shape, requires_grad=True)
[docs]
def from_numpy(np_array):
    """Wrap a NumPy array as a tensor, sharing its memory where possible.

    ``torch.from_numpy`` and ``torch.as_tensor`` both avoid a copy, but
    ``torch.from_numpy`` cannot handle the device, so ``as_tensor`` is used.
    Background:
    https://discuss.pytorch.org/t/from-numpy-vs-as-tensor/79932
    https://stackoverflow.com/questions/48482787/pytorch-memory-model-torch-from-numpy-vs-torch-tensor
    """
    converted = torch.as_tensor(np_array)
    return converted
[docs]
def to_numpy(input_tensor):
    """Return a NumPy copy/view of ``input_tensor`` detached from autograd and on CPU."""
    detached = input_tensor.detach()
    host_tensor = detached.cpu()
    return host_tensor.numpy()
[docs]
def concat(values, axis):
    """Concatenate a sequence of tensors along an existing axis."""
    joined = torch.cat(values, axis)
    return joined
[docs]
def stack(values, axis):
    """Stack same-shaped tensors along a new axis inserted at ``axis``."""
    stacked = torch.stack(values, axis)
    return stacked
[docs]
def expand_dims(tensor, axis):
    """Insert a size-1 dimension into ``tensor`` at position ``axis``."""
    expanded = torch.unsqueeze(tensor, axis)
    return expanded
[docs]
def reverse(tensor, axis):
    """Reverse the order of elements along one or more axes.

    Args:
        tensor: Input tensor.
        axis: A single axis (``int``) or a sequence of axes to flip.

    Returns:
        A new tensor with the selected axes reversed.
    """
    # torch.flip only accepts a list/tuple for ``dims``; wrap a bare int so
    # callers can pass either form, as with numpy.flip.
    if isinstance(axis, int):
        axis = (axis,)
    return torch.flip(tensor, axis)
[docs]
def roll(tensor, shift, axis):
    """Cyclically shift ``tensor`` by ``shift`` positions along ``axis``."""
    rolled = torch.roll(tensor, shift, axis)
    return rolled
[docs]
def lgamma(x):
    """Element-wise natural log of the absolute value of the gamma function."""
    result = torch.lgamma(x)
    return result
[docs]
def elu(x):
    """Exponential linear unit with torch's default ``alpha``."""
    nn_functional = torch.nn.functional
    return nn_functional.elu(x)
[docs]
def relu(x):
    """Rectified linear unit: ``max(x, 0)`` element-wise."""
    nn_functional = torch.nn.functional
    return nn_functional.relu(x)
[docs]
def gelu(x):
    """Gaussian error linear unit using torch's default (exact) formulation."""
    nn_functional = torch.nn.functional
    return nn_functional.gelu(x)
[docs]
def selu(x):
    """Scaled exponential linear unit."""
    nn_functional = torch.nn.functional
    return nn_functional.selu(x)
[docs]
def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` element-wise."""
    nn_functional = torch.nn.functional
    return nn_functional.sigmoid(x)
[docs]
def silu(x):
    """Sigmoid linear unit (swish): ``x * sigmoid(x)`` element-wise."""
    nn_functional = torch.nn.functional
    return nn_functional.silu(x)
[docs]
def sin(x):
    """Element-wise sine."""
    result = torch.sin(x)
    return result
[docs]
def cos(x):
    """Element-wise cosine."""
    result = torch.cos(x)
    return result
[docs]
def exp(x):
    """Element-wise natural exponential."""
    result = torch.exp(x)
    return result
[docs]
def square(x):
    """Element-wise square ``x * x``."""
    result = torch.square(x)
    return result
# pylint: disable=redefined-builtin
[docs]
def abs(x):
    """Element-wise absolute value."""
    result = torch.abs(x)
    return result
[docs]
def minimum(x, y):
    """Element-wise minimum of two (broadcastable) tensors."""
    smaller = torch.minimum(x, y)
    return smaller
[docs]
def tanh(x):
    """Element-wise hyperbolic tangent."""
    result = torch.tanh(x)
    return result
[docs]
def pow(x, y):
    """Raise ``x`` to the power ``y`` element-wise (with broadcasting)."""
    powered = torch.pow(x, y)
    return powered
[docs]
def mean(input_tensor, dim, keepdims=False):
    """Mean of ``input_tensor`` over ``dim``; keep reduced dims when ``keepdims``."""
    reduced = torch.mean(input_tensor, dim, keepdim=keepdims)
    return reduced
[docs]
def reduce_mean(input_tensor):
    """Mean over all elements, returned as a 0-d tensor."""
    overall = torch.mean(input_tensor)
    return overall
[docs]
def sum(input_tensor, dim, keepdims=False):
    """Sum of ``input_tensor`` over ``dim``; keep reduced dims when ``keepdims``."""
    reduced = torch.sum(input_tensor, dim, keepdim=keepdims)
    return reduced
[docs]
def reduce_sum(input_tensor):
    """Sum over all elements, returned as a 0-d tensor."""
    overall = torch.sum(input_tensor)
    return overall
[docs]
def prod(input_tensor, dim, keepdims=False):
    """Product of ``input_tensor`` over ``dim``; keep reduced dims when ``keepdims``."""
    reduced = torch.prod(input_tensor, dim, keepdim=keepdims)
    return reduced
[docs]
def reduce_prod(input_tensor):
    """Product of all elements, returned as a 0-d tensor."""
    overall = torch.prod(input_tensor)
    return overall
# pylint: disable=redefined-builtin
[docs]
def min(input_tensor, dim, keepdims=False):
    """Minimum values of ``input_tensor`` over ``dim`` (values only, no indices)."""
    reduced = torch.amin(input_tensor, dim, keepdim=keepdims)
    return reduced
[docs]
def reduce_min(input_tensor):
    """Smallest element of ``input_tensor`` as a 0-d tensor."""
    smallest = torch.min(input_tensor)
    return smallest
# pylint: disable=redefined-builtin
[docs]
def max(input_tensor, dim, keepdims=False):
    """Maximum values of ``input_tensor`` over ``dim`` (values only, no indices)."""
    reduced = torch.amax(input_tensor, dim, keepdim=keepdims)
    return reduced
[docs]
def reduce_max(input_tensor):
    """Largest element of ``input_tensor`` as a 0-d tensor."""
    largest = torch.max(input_tensor)
    return largest
[docs]
def norm(tensor, ord=None, axis=None, keepdims=False):
    """Vector or matrix norm via ``torch.linalg.norm``.

    Args:
        tensor: Input tensor.
        ord: Order of the norm; ``None`` uses torch's default.
        axis: Dimension(s) to reduce over; ``None`` uses torch's default.
        keepdims: Keep reduced dimensions with size 1.
    """
    linalg = torch.linalg
    return linalg.norm(tensor, ord=ord, dim=axis, keepdim=keepdims)
[docs]
def zeros(shape, dtype):
    """Return a zero-filled tensor of the given ``shape`` and ``dtype``."""
    filled = torch.zeros(shape, dtype=dtype)
    return filled
[docs]
def zeros_like(input_tensor):
    """Return zeros with the same shape, dtype and device as ``input_tensor``."""
    filled = torch.zeros_like(input_tensor)
    return filled
[docs]
def matmul(x, y):
    """Matrix product of ``x`` and ``y``.

    Two 2-D operands give the ordinary matrix product (as ``torch.mm`` would).
    ``torch.matmul`` additionally accepts 1-D operands and batched inputs with
    broadcast batch dimensions.

    Returns:
        The product tensor.
    """
    # torch.matmul dispatches to mm for the 2-D case (sparse 2-D included), so
    # existing callers see identical results while batched inputs also work.
    return torch.matmul(x, y)
[docs]
def sparse_dense_matmul(x, y):
    """Multiply sparse matrix ``x`` by dense matrix ``y`` via ``torch.sparse.mm``."""
    product = torch.sparse.mm(x, y)
    return product