# Source code for deepxde.backend.pytorch.tensor

"""pytorch backend implementation"""
from packaging.version import Version

import torch


# Fail fast at import time when the installed PyTorch is older than the minimum
# version this backend supports.
if Version(torch.__version__) < Version("2.0.0"):
    raise RuntimeError("DeepXDE requires PyTorch>=2.0.0.")

# To write device-agnostic (CPU or GPU) code, a common pattern is to first determine
# torch.device and then use it for all the tensors.
# https://pytorch.org/docs/stable/notes/cuda.html
# >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# >>> tensor.to(device=device)
# But, taking care of all tensors requires a lot of work.
# An alternative way is to use GPU by default if GPU is available, which is similar to
# TensorFlow.
# NOTE: this runs once at import time and changes PyTorch's process-wide default, so
# it affects every tensor created without an explicit device after this import.
if torch.cuda.is_available():
    if Version(torch.__version__) >= Version("2.1.0"):
        # Newer PyTorch: set only the default device.
        torch.set_default_device("cuda")
    else:
        # Older PyTorch: fall back to the legacy API, which sets the default tensor
        # type (CUDA float32) and therefore also the default device.
        torch.set_default_tensor_type(torch.cuda.FloatTensor)


# Module-level alias for the torch module itself.
lib = torch


def data_type_dict():
    """Return the mapping from DeepXDE dtype names to ``torch.dtype`` objects."""
    # Each supported name is also the attribute name of that dtype on ``torch``.
    names = (
        "float16",
        "float32",
        "float64",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "bool",
    )
    return {name: getattr(torch, name) for name in names}
def is_gpu_available():
    """Report whether PyTorch can use a CUDA device."""
    cuda = torch.cuda
    return cuda.is_available()
def is_tensor(obj):
    """Return True when ``obj`` is a ``torch.Tensor`` (subclasses included)."""
    return isinstance(obj, torch.Tensor)
def shape(input_tensor):
    """Return the dimensions of ``input_tensor`` as a plain Python list."""
    return [*input_tensor.shape]
def size(tensor):
    """Return the total number of elements stored in ``tensor``."""
    element_count = torch.numel(tensor)
    return element_count
def ndim(input_tensor):
    """Return the rank (number of dimensions) of ``input_tensor``."""
    return input_tensor.ndim
def transpose(tensor, axes=None):
    """Permute the dimensions of ``tensor``.

    When ``axes`` is None every dimension is reversed, as in ``numpy.transpose``.
    """
    if axes is not None:
        return torch.permute(tensor, axes)
    reversed_axes = tuple(reversed(range(tensor.dim())))
    return torch.permute(tensor, reversed_axes)
def reshape(tensor, shape):
    """Return ``tensor`` with the given ``shape`` (a view when memory allows)."""
    return tensor.reshape(shape)
def Variable(initial_value, dtype=None):
    """Create a new trainable leaf tensor initialised from ``initial_value``."""
    # torch.tensor always copies its input, so the result never aliases it.
    leaf = torch.tensor(initial_value, dtype=dtype, requires_grad=True)
    return leaf
def as_tensor(data, dtype=None):
    """Convert ``data`` to a tensor, reusing it unchanged when no cast is needed."""
    if not isinstance(data, torch.Tensor):
        return torch.as_tensor(data, dtype=dtype)
    needs_cast = dtype is not None and data.dtype != dtype
    return data.type(dtype=dtype) if needs_cast else data
def sparse_tensor(indices, values, shape):
    """Build a sparse COO tensor that tracks gradients.

    Args:
        indices: Sequence of coordinate tuples, one per stored element
            (e.g. ``[(row, col), ...]`` for a matrix). May be empty.
        values: The stored values, in the same order as ``indices``.
        shape: Shape of the dense tensor being represented.

    Returns:
        A ``torch.sparse_coo_tensor`` with ``requires_grad=True``.
    """
    # torch expects one row per dimension, so transpose the per-element tuples.
    coords = list(zip(*indices))
    if not coords:
        # With no stored elements zip(*indices) yields nothing, and torch cannot
        # read [] as an (ndim, nnz) integer index array. Pass an explicit empty
        # int64 index tensor of shape (ndim, 0) instead.
        coords = torch.empty((len(shape), 0), dtype=torch.long)
    return torch.sparse_coo_tensor(coords, values, shape, requires_grad=True)
def from_numpy(np_array):
    """Wrap a NumPy array as a tensor, sharing memory when the target device allows."""
    # Both torch.from_numpy and torch.as_tensor work without a memory copy:
    # https://discuss.pytorch.org/t/from-numpy-vs-as-tensor/79932
    # https://stackoverflow.com/questions/48482787/pytorch-memory-model-torch-from-numpy-vs-torch-tensor
    # torch.as_tensor is chosen because torch.from_numpy cannot handle device.
    tensor = torch.as_tensor(np_array)
    return tensor
def to_numpy(input_tensor):
    """Return a NumPy copy/view of ``input_tensor`` detached from autograd, on the CPU."""
    detached = input_tensor.detach()
    return detached.cpu().numpy()
def concat(values, axis):
    """Concatenate a sequence of tensors along an existing dimension ``axis``."""
    return torch.cat(values, dim=axis)
def stack(values, axis):
    """Stack equally-shaped tensors along a new dimension inserted at ``axis``."""
    return torch.stack(values, dim=axis)
def expand_dims(tensor, axis):
    """Insert a size-1 dimension at position ``axis``."""
    return torch.unsqueeze(tensor, dim=axis)
def reverse(tensor, axis):
    """Reverse the order of elements along the given dimension(s).

    Args:
        tensor: Input tensor.
        axis: A sequence of dimensions to flip, or a single ``int`` dimension.

    Returns:
        A new tensor with the selected dimensions reversed.
    """
    # torch.flip only accepts a sequence of dims; wrap a bare int so callers can
    # pass a scalar axis, as NumPy's ``flip`` allows.
    if isinstance(axis, int):
        axis = (axis,)
    return torch.flip(tensor, axis)
def roll(tensor, shift, axis):
    """Cyclically shift elements by ``shift`` positions along ``axis``."""
    return torch.roll(tensor, shifts=shift, dims=axis)
def lgamma(x):
    """Element-wise natural log of the absolute value of the gamma function."""
    return torch.lgamma(input=x)
def elu(x):
    """Exponential linear unit with the default ``alpha``."""
    return torch.nn.functional.elu(input=x)
def relu(x):
    """Rectified linear unit: clamp negative entries to zero."""
    return torch.nn.functional.relu(input=x)
def gelu(x):
    """Gaussian error linear unit (exact, non-approximated form)."""
    activation = torch.nn.functional.gelu
    return activation(x)
def selu(x):
    """Scaled exponential linear unit."""
    return torch.nn.functional.selu(input=x)
def sigmoid(x):
    """Element-wise logistic sigmoid ``1 / (1 + exp(-x))``."""
    # The core op, like the torch.tanh / torch.exp calls used elsewhere here.
    return torch.sigmoid(x)
def silu(x):
    """Sigmoid linear unit (swish): ``x * sigmoid(x)``."""
    return torch.nn.functional.silu(input=x)
def sin(x):
    """Element-wise sine."""
    return torch.sin(input=x)
def cos(x):
    """Element-wise cosine."""
    return torch.cos(input=x)
def exp(x):
    """Element-wise natural exponential."""
    return torch.exp(input=x)
def square(x):
    """Element-wise square ``x * x``."""
    return torch.square(input=x)
# pylint: disable=redefined-builtin
def abs(x):
    """Element-wise absolute value."""
    return torch.abs(input=x)
def minimum(x, y):
    """Element-wise minimum of two tensors (with broadcasting)."""
    return torch.minimum(input=x, other=y)
def tanh(x):
    """Element-wise hyperbolic tangent."""
    return torch.tanh(input=x)
def pow(x, y):
    """Element-wise power ``x ** y``; one of the operands may be a Python scalar."""
    base, exponent = x, y
    return torch.pow(base, exponent)
def mean(input_tensor, dim, keepdims=False):
    """Average over ``dim``, optionally keeping the reduced dimension."""
    return torch.mean(input_tensor, dim=dim, keepdim=keepdims)
def reduce_mean(input_tensor):
    """Average over every element, returning a 0-d tensor."""
    return torch.mean(input=input_tensor)
def sum(input_tensor, dim, keepdims=False):
    """Sum over ``dim``, optionally keeping the reduced dimension."""
    return torch.sum(input_tensor, dim=dim, keepdim=keepdims)
def reduce_sum(input_tensor):
    """Sum every element, returning a 0-d tensor."""
    return torch.sum(input=input_tensor)
def prod(input_tensor, dim, keepdims=False):
    """Product over ``dim``, optionally keeping the reduced dimension."""
    return torch.prod(input_tensor, dim=dim, keepdim=keepdims)
def reduce_prod(input_tensor):
    """Product of every element, returning a 0-d tensor."""
    return torch.prod(input=input_tensor)
# pylint: disable=redefined-builtin
def min(input_tensor, dim, keepdims=False):
    """Minimum values over ``dim`` (values only, no indices)."""
    return torch.amin(input_tensor, dim=dim, keepdim=keepdims)
def reduce_min(input_tensor):
    """Smallest element of the whole tensor, as a 0-d tensor."""
    return torch.min(input=input_tensor)
# pylint: disable=redefined-builtin
def max(input_tensor, dim, keepdims=False):
    """Maximum values over ``dim`` (values only, no indices)."""
    return torch.amax(input_tensor, dim=dim, keepdim=keepdims)
def reduce_max(input_tensor):
    """Largest element of the whole tensor, as a 0-d tensor."""
    return torch.max(input=input_tensor)
def norm(tensor, ord=None, axis=None, keepdims=False):
    """Vector or matrix norm via ``torch.linalg.norm``, with NumPy-style argument names."""
    # Translate NumPy-style ``axis``/``keepdims`` to PyTorch's ``dim``/``keepdim``.
    options = {"ord": ord, "dim": axis, "keepdim": keepdims}
    return torch.linalg.norm(tensor, **options)
def zeros(shape, dtype):
    """Return a new tensor of the given ``shape`` and ``dtype`` filled with zeros."""
    filled = torch.zeros(shape, dtype=dtype)
    return filled
def zeros_like(input_tensor):
    """Return zeros with the same shape, dtype and device as ``input_tensor``."""
    return torch.zeros_like(input=input_tensor)
def matmul(x, y):
    """Matrix product of two 2-D tensors (no broadcasting; see ``torch.mm``)."""
    return torch.mm(input=x, mat2=y)
def sparse_dense_matmul(x, y):
    """Multiply a sparse matrix ``x`` by a dense matrix ``y``."""
    sparse_lhs, dense_rhs = x, y
    return torch.sparse.mm(sparse_lhs, dense_rhs)