"""jax backend implementation"""
import jax
import jax.numpy as jnp
import numpy as np
# Handle to this backend's underlying library: the top-level `jax` module
# itself (not `jax.numpy`).
lib = jax
[docs]
def data_type_dict():
    """Map backend-neutral dtype names to the matching jax.numpy types.

    Returns:
        dict: A freshly built dict from name strings ("float32", "int64",
        "bool", ...) to the corresponding ``jnp`` scalar types.
    """
    # Every entry except "bool" has a jnp attribute with exactly that name.
    same_named = (
        "float16",
        "float32",
        "float64",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
    )
    mapping = {name: getattr(jnp, name) for name in same_named}
    # jnp spells the boolean type `bool_` to avoid clashing with the builtin.
    mapping["bool"] = jnp.bool_
    return mapping
[docs]
def is_tensor(obj):
    """Report whether ``obj`` is a JAX array.

    Args:
        obj: Any Python object.

    Returns:
        bool: True when ``obj`` is an instance of ``jnp.ndarray``.
    """
    array_type = jnp.ndarray
    return isinstance(obj, array_type)
[docs]
def shape(input_array):
    """Return the ``.shape`` attribute of ``input_array``.

    Args:
        input_array: Any object exposing ``.shape`` (e.g. a JAX array).

    Returns:
        The attribute value unchanged (a tuple of ints for JAX arrays).
    """
    dims = input_array.shape
    return dims
[docs]
def ndim(input_array):
    """Return the number of dimensions (``.ndim``) of ``input_array``.

    Args:
        input_array: Any object exposing ``.ndim`` (e.g. a JAX array).

    Returns:
        int: The array rank.
    """
    rank = input_array.ndim
    return rank
[docs]
def transpose(tensor, axes=None):
    """Permute the axes of ``tensor`` with ``jnp.transpose``.

    Args:
        tensor: Input array.
        axes: Optional permutation of axis indices. ``None`` (the default)
            lets ``jnp.transpose`` reverse the axis order.

    Returns:
        The transposed array.
    """
    permutation = axes
    transposed = jnp.transpose(tensor, axes=permutation)
    return transposed
[docs]
def reshape(tensor, shape):
    """Give ``tensor`` a new ``shape`` via ``jnp.reshape``.

    Args:
        tensor: Input array.
        shape: Target shape; may contain a single ``-1`` to be inferred.

    Returns:
        The reshaped array.
    """
    target = shape
    reshaped = jnp.reshape(tensor, target)
    return reshaped
[docs]
class Variable:
    """Small mutable holder wrapping an array in a reassignable ``value``.

    Args:
        initial_value: Data converted with ``jnp.array`` on construction.
        dtype: Optional dtype forwarded to ``jnp.array``.
    """

    def __init__(self, initial_value, dtype=None):
        converted = jnp.array(initial_value, dtype=dtype)
        self._value = converted

    @property
    def value(self):
        """The currently held array, or whatever was last assigned."""
        return self._value

    @value.setter
    def value(self, new_value):
        # Stored verbatim: unlike __init__, no jnp conversion happens here.
        self._value = new_value
[docs]
def as_tensor(data, dtype=None):
    """Convert ``data`` to a JAX array, avoiding work when possible.

    Args:
        data: A JAX array or anything ``jnp.asarray`` accepts.
        dtype: Optional target dtype.

    Returns:
        For non-JAX input, ``jnp.asarray(data, dtype=dtype)``. For a JAX
        array, the same object when ``dtype`` is None or already matches,
        otherwise ``data.astype(dtype)``.
    """
    if not isinstance(data, jnp.ndarray):
        return jnp.asarray(data, dtype=dtype)
    needs_cast = dtype is not None and not data.dtype == dtype
    return data.astype(dtype) if needs_cast else data
[docs]
def from_numpy(np_array):
    """Turn a NumPy array into a JAX array using ``jnp.asarray``.

    Args:
        np_array: Source ``numpy.ndarray`` (or other array-like input).

    Returns:
        The corresponding JAX array.
    """
    converted = jnp.asarray(np_array)
    return converted
[docs]
def to_numpy(input_tensor):
    """Materialize ``input_tensor`` as a NumPy array with ``np.asarray``.

    Args:
        input_tensor: A JAX array (or other array-like input).

    Returns:
        numpy.ndarray: The host-side array.
    """
    host_array = np.asarray(input_tensor)
    return host_array
[docs]
def concat(values, axis):
    """Join a sequence of arrays along an existing axis.

    Args:
        values: Sequence of arrays with matching shapes off ``axis``.
        axis: Axis along which to concatenate.

    Returns:
        The concatenated array from ``jnp.concatenate``.
    """
    joined = jnp.concatenate(values, axis=axis)
    return joined
[docs]
def stack(values, axis):
    """Stack equally shaped arrays along a new axis.

    Args:
        values: Sequence of arrays, all of the same shape.
        axis: Position of the new axis in the result.

    Returns:
        The stacked array from ``jnp.stack``.
    """
    stacked = jnp.stack(values, axis=axis)
    return stacked
[docs]
def elu(x):
    """Apply the exponential linear unit (``jax.nn.elu``) element-wise."""
    activation = jax.nn.elu
    return activation(x)
[docs]
def relu(x):
    """Apply the rectified linear unit (``jax.nn.relu``) element-wise."""
    activation = jax.nn.relu
    return activation(x)
[docs]
def selu(x):
    """Apply the scaled exponential linear unit (``jax.nn.selu``) element-wise."""
    activation = jax.nn.selu
    return activation(x)
[docs]
def sigmoid(x):
    """Apply the logistic sigmoid (``jax.nn.sigmoid``) element-wise."""
    activation = jax.nn.sigmoid
    return activation(x)
[docs]
def silu(x):
    """Apply the SiLU / swish activation (``jax.nn.silu``) element-wise."""
    activation = jax.nn.silu
    return activation(x)
[docs]
def sin(x):
    """Element-wise sine of ``x`` (``jnp.sin``)."""
    result = jnp.sin(x)
    return result
[docs]
def cos(x):
    """Element-wise cosine of ``x`` (``jnp.cos``)."""
    result = jnp.cos(x)
    return result
[docs]
def square(x):
    """Element-wise square of ``x`` (``jnp.square``)."""
    result = jnp.square(x)
    return result
# pylint: disable=redefined-builtin
[docs]
def abs(x):
    """Element-wise absolute value of ``x`` (``jnp.abs``)."""
    magnitude = jnp.abs(x)
    return magnitude
[docs]
def minimum(x, y):
    """Element-wise minimum of ``x`` and ``y`` (``jnp.minimum``, broadcasting)."""
    smaller = jnp.minimum(x, y)
    return smaller
[docs]
def tanh(x):
    """Element-wise hyperbolic tangent of ``x`` (``jnp.tanh``)."""
    result = jnp.tanh(x)
    return result
[docs]
def mean(input_tensor, dim, keepdims=False):
    """Arithmetic mean of ``input_tensor`` along axis ``dim``.

    Args:
        input_tensor: Array to reduce.
        dim: Axis (or axes) forwarded to ``jnp.mean`` as ``axis``.
        keepdims: When True, keep reduced axes as size-1 dimensions.

    Returns:
        The reduced array.
    """
    reduced = jnp.mean(input_tensor, axis=dim, keepdims=keepdims)
    return reduced
[docs]
def reduce_mean(input_tensor):
    """Mean over every element of ``input_tensor``, returned as a 0-d array."""
    # axis=None is jnp.mean's default; spelled out to make "all axes" explicit.
    overall = jnp.mean(input_tensor, axis=None)
    return overall
[docs]
def sum(input_tensor, dim, keepdims=False):
    """Sum of ``input_tensor`` along axis ``dim``.

    Args:
        input_tensor: Array to reduce.
        dim: Axis (or axes) forwarded to ``jnp.sum`` as ``axis``.
        keepdims: When True, keep reduced axes as size-1 dimensions.

    Returns:
        The reduced array.
    """
    reduced = jnp.sum(input_tensor, axis=dim, keepdims=keepdims)
    return reduced
[docs]
def reduce_sum(input_tensor):
    """Sum over every element of ``input_tensor``, returned as a 0-d array."""
    # axis=None is jnp.sum's default; spelled out to make "all axes" explicit.
    total = jnp.sum(input_tensor, axis=None)
    return total
[docs]
def prod(input_tensor, dim, keepdims=False):
    """Product of ``input_tensor`` along axis ``dim``.

    Args:
        input_tensor: Array to reduce.
        dim: Axis (or axes) forwarded to ``jnp.prod`` as ``axis``.
        keepdims: When True, keep reduced axes as size-1 dimensions.

    Returns:
        The reduced array.
    """
    reduced = jnp.prod(input_tensor, axis=dim, keepdims=keepdims)
    return reduced
[docs]
def reduce_prod(input_tensor):
    """Product of every element of ``input_tensor``, returned as a 0-d array."""
    # axis=None is jnp.prod's default; spelled out to make "all axes" explicit.
    total = jnp.prod(input_tensor, axis=None)
    return total
# pylint: disable=redefined-builtin
[docs]
def min(input_tensor, dim, keepdims=False):
    """Minimum of ``input_tensor`` along axis ``dim``.

    Args:
        input_tensor: Array to reduce.
        dim: Axis (or axes) forwarded to ``jnp.min`` as ``axis``.
        keepdims: When True, keep reduced axes as size-1 dimensions.

    Returns:
        The reduced array.
    """
    reduced = jnp.min(input_tensor, axis=dim, keepdims=keepdims)
    return reduced
[docs]
def reduce_min(input_tensor):
    """Smallest element of ``input_tensor``, returned as a 0-d array."""
    # axis=None is jnp.min's default; spelled out to make "all axes" explicit.
    smallest = jnp.min(input_tensor, axis=None)
    return smallest
# pylint: disable=redefined-builtin
[docs]
def max(input_tensor, dim, keepdims=False):
    """Maximum of ``input_tensor`` along axis ``dim``.

    Args:
        input_tensor: Array to reduce.
        dim: Axis (or axes) forwarded to ``jnp.max`` as ``axis``.
        keepdims: When True, keep reduced axes as size-1 dimensions.

    Returns:
        The reduced array.
    """
    reduced = jnp.max(input_tensor, axis=dim, keepdims=keepdims)
    return reduced
[docs]
def reduce_max(input_tensor):
    """Largest element of ``input_tensor``, returned as a 0-d array."""
    # axis=None is jnp.max's default; spelled out to make "all axes" explicit.
    largest = jnp.max(input_tensor, axis=None)
    return largest
[docs]
def zeros(shape, dtype):
    """Create a zero-filled array.

    Args:
        shape: Shape of the result (int or tuple of ints).
        dtype: Element type forwarded to ``jnp.zeros``.

    Returns:
        A new array of zeros.
    """
    filled = jnp.zeros(shape, dtype=dtype)
    return filled
[docs]
def zeros_like(input_tensor):
    """Create zeros with the same shape and dtype as ``input_tensor``."""
    template = input_tensor
    filled = jnp.zeros_like(template)
    return filled