# Source code for statinf.ml.activations

import numpy as np
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import expit

# Default activation functions

def sigmoid(x):
    r"""Apply the logistic sigmoid to the input.

    :param x: Input value
    :type x: :obj:`float` or :obj:`numpy.array`

    :return: Sigmoid activated value: :math:`sigmoid(x) = \dfrac{1}{1 + e^{-x}}`
    :rtype: :obj:`float`
    """
    # Delegate to JAX's ``expit`` rather than spelling out 1 / (1 + exp(-x)).
    activated = expit(x)
    return activated

def relu(x):
    r"""Apply the Rectified Linear Unit: negative inputs are clamped to zero.

    :param x: Input value
    :type x: :obj:`float` or :obj:`numpy.array`

    :return: Activated value: :math:`\mathrm{relu}(x) = \max(0, x)`
    :rtype: :obj:`float`
    """
    lower_bound = 0
    rectified = jnp.maximum(lower_bound, x)
    return rectified

def elu(x, alpha=1.):
    r"""Apply the Exponential Linear Unit activation.

    :param x: Input value
    :type x: :obj:`float` or :obj:`numpy.array`
    :param alpha: Scale applied to the exponential branch, defaults to 1.
    :type alpha: :obj:`float`

    :formula: .. math:: \mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(e^{x} - 1\right), & x \le 0 \end{cases}

    :return: Activated value.
    :rtype: :obj:`float`
    """
    is_positive = x > 0
    # Positive entries are replaced by 0 before ``expm1`` so the branch that
    # is discarded by the final ``where`` is never fed large positive values.
    clipped = jnp.where(is_positive, 0., x)
    exponential_branch = alpha * jnp.expm1(clipped)
    return jnp.where(is_positive, x, exponential_branch)

def tanh(x):
    r"""Hyperbolic tangent activation function.

    :param x: Input value
    :type x: :obj:`float` or :obj:`numpy.array`

    :return: Activated value: :math:`\tanh(x)`
    :rtype: :obj:`float`
    """
    # Must be the hyperbolic tangent itself; a logarithm here would return
    # NaN for every negative input and the wrong value everywhere else.
    return jnp.tanh(x)

def softplus(x):
    r"""Softplus activation function.

    :param x: Input value
    :type x: :obj:`float` or :obj:`numpy.array`

    :return: Activated value: :math:`\mathrm{softplus}(x) = \log(1 + e^{x})`
    :rtype: :obj:`float`
    """
    # log(exp(x) + exp(0)) == log(1 + exp(x)). ``logaddexp`` evaluates this
    # without overflowing for large |x|, unlike the literal log(1 + exp(x)).
    return jnp.logaddexp(x, 0.)

def softmax(x, axis=-1):
    r"""Normalise the input into positive weights that sum to one along ``axis``.

    :param x: Input value
    :type x: :obj:`float` or :obj:`numpy.array`
    :param axis: Axis along which the normalisation is performed, defaults to -1.
    :type axis: :obj:`int`

    :return: Activated value: :math:`\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`
    :rtype: :obj:`float`
    """
    # Shift by the maximum along ``axis`` before exponentiating; the shift is
    # wrapped in ``stop_gradient`` so it is held constant under differentiation.
    peak = lax.stop_gradient(x.max(axis, keepdims=True))
    exponentials = jnp.exp(x - peak)
    normaliser = exponentials.sum(axis, keepdims=True)
    return exponentials / normaliser

def logit(x, weights, bias=0):
    r"""Logistic function applied to the linear predictor ``x.dot(weights) + bias``.

    :param x: Input value
    :type x: numpy.array
    :param weights: Vector of weights :math:`\beta`
    :type weights: numpy.array
    :param bias: Vector of bias :math:`\epsilon`, defaults to 0.
    :type bias: numpy.array

    :return: Logistic transformation: :math:`\mathrm{logit}(x, \beta) = \dfrac{1}{1 + e^{-(x \beta + \epsilon)}}`
    :rtype: float
    """
    linear_predictor = x.dot(weights) + bias
    # Negate the whole predictor: negating only ``x.dot(weights)`` would make
    # the bias enter with the opposite sign.
    return 1 / (1 + np.exp(-linear_predictor))