# Source code for deepxde.backend.jax.tensor

"""jax backend implementation"""
import jax
import jax.numpy as jnp
import numpy as np


# Module object of the underlying framework; this backend binds it to ``jax``.
# NOTE(review): presumably read by backend-agnostic code to reach the raw
# framework — confirm against callers.
lib = jax


def data_type_dict():
    """Map dtype name strings to the corresponding ``jax.numpy`` dtypes.

    Returns:
        dict: A new mapping from names such as ``"float32"`` to the
        ``jax.numpy`` scalar types (``"bool"`` maps to ``jnp.bool_``).
    """
    return dict(
        float16=jnp.float16,
        float32=jnp.float32,
        float64=jnp.float64,
        uint8=jnp.uint8,
        int8=jnp.int8,
        int16=jnp.int16,
        int32=jnp.int32,
        int64=jnp.int64,
        bool=jnp.bool_,
    )
def is_tensor(obj):
    """Return ``True`` when *obj* is an instance of ``jnp.ndarray``."""
    tensor_type = jnp.ndarray
    return isinstance(obj, tensor_type)
def shape(input_array):
    """Return the ``shape`` attribute of *input_array*."""
    dims = input_array.shape
    return dims
def ndim(input_array):
    """Return the number of dimensions (``ndim``) of *input_array*."""
    rank = input_array.ndim
    return rank
def transpose(tensor, axes=None):
    """Permute the axes of *tensor* with ``jnp.transpose``.

    Args:
        tensor: Array-like input.
        axes: Optional permutation of the axes; ``None`` reverses them.
    """
    permuted = jnp.transpose(tensor, axes=axes)
    return permuted
def reshape(tensor, shape):
    """Return *tensor* reshaped to *shape* with ``jnp.reshape``."""
    reshaped = jnp.reshape(tensor, shape)
    return reshaped
class Variable:
    """Holder for a value exposed through the :attr:`value` property.

    The initial value is converted with ``jnp.array`` (optionally to
    *dtype*). Later assignments to :attr:`value` are stored as given,
    without any conversion.
    """

    def __init__(self, initial_value, dtype=None):
        converted = jnp.array(initial_value, dtype=dtype)
        self._value = converted

    @property
    def value(self):
        """The currently stored value."""
        return self._value

    @value.setter
    def value(self, new_value):
        # Stored verbatim; no conversion to a jnp array happens here.
        self._value = new_value
def as_tensor(data, dtype=None):
    """Convert *data* to a ``jax.numpy`` array, optionally casting it.

    If *data* is already a ``jnp.ndarray``, it is returned unchanged when
    *dtype* is ``None`` or already matches its dtype; otherwise a cast copy
    is produced with ``astype``. Any other input goes through ``jnp.asarray``.
    """
    if not isinstance(data, jnp.ndarray):
        return jnp.asarray(data, dtype=dtype)
    keep_as_is = dtype is None or data.dtype == dtype
    return data if keep_as_is else data.astype(dtype)
def from_numpy(np_array):
    """Create a ``jax.numpy`` array from *np_array* with ``jnp.asarray``."""
    converted = jnp.asarray(np_array)
    return converted
def to_numpy(input_tensor):
    """Return *input_tensor* as a NumPy array via ``np.asarray``."""
    host_array = np.asarray(input_tensor)
    return host_array
def concat(values, axis):
    """Join the arrays in *values* along the existing dimension *axis*."""
    joined = jnp.concatenate(values, axis=axis)
    return joined
def stack(values, axis):
    """Stack the arrays in *values* along a new dimension at *axis*."""
    stacked = jnp.stack(values, axis=axis)
    return stacked
def elu(x):
    """Apply ``jax.nn.elu`` (exponential linear unit) to *x*."""
    activated = jax.nn.elu(x)
    return activated
def relu(x):
    """Apply ``jax.nn.relu`` (rectified linear unit) to *x*."""
    activated = jax.nn.relu(x)
    return activated
def selu(x):
    """Apply ``jax.nn.selu`` (scaled exponential linear unit) to *x*."""
    activated = jax.nn.selu(x)
    return activated
def sigmoid(x):
    """Apply the logistic function ``jax.nn.sigmoid`` to *x*."""
    activated = jax.nn.sigmoid(x)
    return activated
def silu(x):
    """Apply ``jax.nn.silu`` (sigmoid linear unit, ``x * sigmoid(x)``) to *x*."""
    activated = jax.nn.silu(x)
    return activated
def sin(x):
    """Return the elementwise sine of *x* (``jnp.sin``)."""
    result = jnp.sin(x)
    return result
def cos(x):
    """Return the elementwise cosine of *x* (``jnp.cos``)."""
    result = jnp.cos(x)
    return result
def square(x):
    """Return the elementwise square of *x* (``jnp.square``)."""
    result = jnp.square(x)
    return result
# pylint: disable=redefined-builtin
def abs(x):  # pylint: disable=redefined-builtin
    """Return the elementwise absolute value of *x* (``jnp.abs``)."""
    magnitude = jnp.abs(x)
    return magnitude
def minimum(x, y):
    """Return the elementwise minimum of *x* and *y* (``jnp.minimum``)."""
    smaller = jnp.minimum(x, y)
    return smaller
def tanh(x):
    """Return the elementwise hyperbolic tangent of *x* (``jnp.tanh``)."""
    result = jnp.tanh(x)
    return result
def mean(input_tensor, dim, keepdims=False):
    """Average *input_tensor* over the axis (or axes) *dim*.

    Args:
        input_tensor: Array-like input.
        dim: Axis or axes to reduce, forwarded as ``axis`` to ``jnp.mean``.
        keepdims: If ``True``, reduced axes are kept with length one.
    """
    reduced = jnp.mean(input_tensor, axis=dim, keepdims=keepdims)
    return reduced
def reduce_mean(input_tensor):
    """Return the mean over all elements of *input_tensor*."""
    overall = jnp.mean(input_tensor)
    return overall
def sum(input_tensor, dim, keepdims=False):  # pylint: disable=redefined-builtin
    """Sum *input_tensor* over the axis (or axes) *dim*.

    Args:
        input_tensor: Array-like input.
        dim: Axis or axes to reduce, forwarded as ``axis`` to ``jnp.sum``.
        keepdims: If ``True``, reduced axes are kept with length one.
    """
    reduced = jnp.sum(input_tensor, axis=dim, keepdims=keepdims)
    return reduced
def reduce_sum(input_tensor):
    """Return the sum over all elements of *input_tensor*."""
    total = jnp.sum(input_tensor)
    return total
def prod(input_tensor, dim, keepdims=False):
    """Multiply the elements of *input_tensor* over the axis (or axes) *dim*.

    Args:
        input_tensor: Array-like input.
        dim: Axis or axes to reduce, forwarded as ``axis`` to ``jnp.prod``.
        keepdims: If ``True``, reduced axes are kept with length one.
    """
    reduced = jnp.prod(input_tensor, axis=dim, keepdims=keepdims)
    return reduced
def reduce_prod(input_tensor):
    """Return the product of all elements of *input_tensor*."""
    product = jnp.prod(input_tensor)
    return product
# pylint: disable=redefined-builtin
def min(input_tensor, dim, keepdims=False):  # pylint: disable=redefined-builtin
    """Take the minimum of *input_tensor* over the axis (or axes) *dim*.

    Args:
        input_tensor: Array-like input.
        dim: Axis or axes to reduce, forwarded as ``axis`` to ``jnp.min``.
        keepdims: If ``True``, reduced axes are kept with length one.
    """
    reduced = jnp.min(input_tensor, axis=dim, keepdims=keepdims)
    return reduced
def reduce_min(input_tensor):
    """Return the smallest element of *input_tensor*."""
    smallest = jnp.min(input_tensor)
    return smallest
# pylint: disable=redefined-builtin
def max(input_tensor, dim, keepdims=False):  # pylint: disable=redefined-builtin
    """Take the maximum of *input_tensor* over the axis (or axes) *dim*.

    Args:
        input_tensor: Array-like input.
        dim: Axis or axes to reduce, forwarded as ``axis`` to ``jnp.max``.
        keepdims: If ``True``, reduced axes are kept with length one.
    """
    reduced = jnp.max(input_tensor, axis=dim, keepdims=keepdims)
    return reduced
def reduce_max(input_tensor):
    """Return the largest element of *input_tensor*."""
    largest = jnp.max(input_tensor)
    return largest
def zeros(shape, dtype):
    """Return a zero-filled ``jax.numpy`` array of *shape* and *dtype*."""
    filled = jnp.zeros(shape, dtype=dtype)
    return filled
def zeros_like(input_tensor):
    """Return zeros with the same shape and dtype as *input_tensor*."""
    filled = jnp.zeros_like(input_tensor)
    return filled