Source code for deepxde.losses

from . import backend as bkd
from . import config
from .backend import tf


def mean_absolute_error(y_true, y_pred):
    """Return the mean absolute error between ``y_true`` and ``y_pred``.

    Thin wrapper around ``tf.keras.losses.MeanAbsoluteError``; only the
    TensorFlow backend is supported here.
    """
    # TODO: pytorch
    mae = tf.keras.losses.MeanAbsoluteError()
    return mae(y_true, y_pred)
def mean_absolute_percentage_error(y_true, y_pred):
    """Return the mean absolute percentage error of ``y_pred`` w.r.t. ``y_true``.

    Thin wrapper around ``tf.keras.losses.MeanAbsolutePercentageError``;
    only the TensorFlow backend is supported here.
    """
    # TODO: pytorch
    mape = tf.keras.losses.MeanAbsolutePercentageError()
    return mape(y_true, y_pred)
def mean_squared_error(y_true, y_pred):
    """Return the mean of the squared difference ``y_true - y_pred``.

    Computed with backend-agnostic ``bkd`` ops (``bkd.reduce_mean`` is called
    without an axis) so the inputs' dtype is preserved.
    """
    # Warning:
    # - Do not use ``tf.losses.mean_squared_error``, which casts `y_true` and `y_pred` to ``float32``.
    # - Do not use ``tf.keras.losses.MSE``, which computes the mean value over the last dimension.
    # - Do not use ``tf.keras.losses.MeanSquaredError()``, which casts loss to ``float32``
    #     when calling ``compute_weighted_loss()`` calling ``scale_losses_by_sample_weight()``,
    #     although it finally casts loss back to the original type.
    residual = y_true - y_pred
    return bkd.reduce_mean(bkd.square(residual))
def mean_l2_relative_error(y_true, y_pred):
    """Return the mean over rows of ``||y_true - y_pred|| / ||y_true||``.

    Both norms are taken along ``axis=1``. This presumably means inputs are
    (samples, features) arrays; TODO confirm against callers. A row of
    ``y_true`` with zero norm divides by zero, and the result then follows
    the backend's division semantics.
    """
    error_norm = bkd.norm(y_true - y_pred, axis=1)
    reference_norm = bkd.norm(y_true, axis=1)
    return bkd.reduce_mean(error_norm / reference_norm)
def softmax_cross_entropy(y_true, y_pred):
    """Return the categorical cross-entropy with ``y_pred`` treated as logits.

    Thin wrapper around ``tf.keras.losses.CategoricalCrossentropy`` with
    ``from_logits=True``; only the TensorFlow backend is supported here.
    """
    # TODO: pytorch
    cross_entropy = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    return cross_entropy(y_true, y_pred)
def zero(*_):
    """Return a constant zero loss, ignoring all positional arguments.

    The constant's dtype comes from ``config.real(tf)``.
    """
    # TODO: pytorch
    real_dtype = config.real(tf)
    return tf.constant(0, dtype=real_dtype)
# Maps every accepted string name (long form and short aliases) to its loss
# function. The key order matches the original literal.
LOSS_DICT = {
    alias: loss_fn
    for loss_fn, aliases in (
        (mean_absolute_error, ("mean absolute error", "MAE", "mae")),
        (mean_squared_error, ("mean squared error", "MSE", "mse")),
        (
            mean_absolute_percentage_error,
            ("mean absolute percentage error", "MAPE", "mape"),
        ),
        (mean_l2_relative_error, ("mean l2 relative error",)),
        (softmax_cross_entropy, ("softmax cross entropy",)),
        (zero, ("zero",)),
    )
    for alias in aliases
}
def get(identifier):
    """Retrieves a loss function.

    Args:
        identifier: A loss identifier. String name of a loss function, or a loss
            function. A list or tuple of identifiers is resolved element-wise.

    Returns:
        A loss function, or a list of loss functions if ``identifier`` is a
        list/tuple.

    Raises:
        KeyError: If ``identifier`` is a string not present in ``LOSS_DICT``.
        ValueError: If ``identifier`` is neither a list/tuple, a string, nor a
            callable.
    """
    if isinstance(identifier, (list, tuple)):
        return list(map(get, identifier))
    # Strings are checked before ``callable`` so names are always looked up.
    if isinstance(identifier, str):
        return LOSS_DICT[identifier]
    if callable(identifier):
        return identifier
    # Use a single formatted message. Passing two positional args made
    # str(exc) render as a tuple.
    raise ValueError(f"Could not interpret loss function identifier: {identifier!r}")