# Source code for statinf.ml.losses

import math
import numpy as np
import pandas as pd
import jax.numpy as jnp

from ..misc import format_object

[docs]def log_stability(x, delta=10e-5):
"""Log-stability for computing loss.

:param x: Input value.
:type x: :obj:float
:param delta: Constant to move from 0 or 1, defaults to 10e-9.
:type delta: :obj:float, optional

:return: Stabilized value where :math:\\hat{x} \\in (0, 1)
:rtype: :obj:float
"""
if type(x) == float:
if x == 0:
return delta
elif x == 1:
return 1.0 - delta
else:
return x
else:
new_x = jnp.where(x == 0., delta, x)
return jnp.where(new_x == 1., 1. - delta, new_x)

[docs]def binary_cross_entropy(y_true, y_pred):
"""Binary cross-entropy.

:param y_true: Real values on which to compare.
:type y_true: :obj:numpy.array
:param y_pred: Predicted values.
:type y_pred: :obj:numpy.array

:formula: :math:loss = y_{i} \\log \\left[ \\hat{y}_{i} \\right] + (1 - y_{i}) \\log \\left[1 - \\hat{y}_{i} \\right]

:references: * Friedman, J., Hastie, T. and Tibshirani, R., 2001. The elements of statistical learning <https://web.stanford.edu/~hastie/Papers/ESLII.pdf>_. Ch. 2, pp. 24.

:return: Binary cross-entropy
:rtype: :obj:float
"""
# if tensor:
#     loss = -y_true * T.log(y_pred) - (1-y_true) * T.log(1-y_pred)
# else:
# print(y_true)
loss = jnp.sum(-y_true * jnp.log(log_stability(y_pred)) - (1 - y_true) * jnp.log(log_stability(1 - y_pred)))
# print('Loss is')
# print(loss)
return loss

[docs]def mean_squared_error(y_true, y_pred, root=False):
"""Mean Squared Error.

:param y_true: Real values on which to compare.
:type y_true: :obj:numpy.array
:param y_pred: Predicted values.
:type y_pred: :obj:numpy.array
:param root: Return Root Mean Squared Error (RMSE), defaults to False.
:type root: :obj:bool, optional

:formula: :math:loss = \\dfrac{1}{m} \\times \\sum_{i=1}^{m} (y_i - \\hat{y}_i)^2

:references: * Friedman, J., Hastie, T. and Tibshirani, R., 2001. The elements of statistical learning <https://web.stanford.edu/~hastie/Papers/ESLII.pdf>_. Ch. 2, pp. 24.

:return: Mean Squared Error or its root.
:rtype: :obj:float
"""

loss = jnp.square(y_pred - y_true)
mse = jnp.mean(loss)
if root:
return jnp.sqrt(mse)
else:
return mse

[docs]def binary_accuracy(y_true, y_pred):
"""Accuracy for binary data.

:param y_true: Real values on which to compare.
:type y_true: :obj:numpy.array
:param y_pred: Predicted values
:type y_pred: numpy.array

:return: Binary accuracy (in percent)
:rtype: :obj:float
"""
true = format_object(y_true, to_type='list', name='y_true')
pred = format_object(y_pred, to_type='list', name='y_pred')
# true = y_true if type(y_true) == list else [x[0] for x in np.asarray(y_true)]
# pred = y_pred if type(y_pred) == list else [x[0] for x in np.asarray(y_pred)]

perf = pd.DataFrame({'true': true, 'pred': pred})
return (perf.true == perf.pred).mean()