from jax import config
config.update('jax_enable_x64', True)
from jax import Array
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy
from jax.scipy import special
from jax.scipy.stats import norm
import numpy as np
import scipy.stats as osp_stats
import logging
log = logging.getLogger(__name__)
class _BasicPoisson:
def __init__(self, rate):
self.rate = rate
def sample(self, sample_shape):
# TODO: Support other dtypes
return jnp.asarray(
osp_stats.poisson(self.rate).rvs(size=sample_shape + self.rate.shape),
dtype=jnp.float64,
)
def log_prob(self, value):
tensorlib = jax_backend()
return tensorlib.poisson_logpdf(value, self.rate)
class _BasicNormal:
def __init__(self, loc, scale):
self.loc = loc
self.scale = scale
def sample(self, sample_shape):
# TODO: Support other dtypes
return jnp.asarray(
osp_stats.norm(self.loc, self.scale).rvs(
size=sample_shape + self.loc.shape
),
dtype=jnp.float64,
)
def log_prob(self, value):
tensorlib = jax_backend()
return tensorlib.normal_logpdf(value, self.loc, self.scale)
class jax_backend:
    """JAX backend for pyhf"""

    __slots__ = ['name', 'precision', 'dtypemap', 'default_do_grad']

    #: The array type for jax
    array_type = Array

    #: The array content type for jax
    array_subtype = Array

    def __init__(self, **kwargs):
        """
        Create the JAX backend.

        Args:
            precision (:obj:`str`): ``'64b'`` (default) selects 64-bit float and
              int dtypes; any other value selects 32-bit ones.
        """
        self.name = 'jax'
        self.precision = kwargs.get('precision', '64b')
        # Map the backend-agnostic dtype names onto concrete JAX dtypes for the
        # requested precision.
        self.dtypemap = {
            'float': jnp.float64 if self.precision == '64b' else jnp.float32,
            'int': jnp.int64 if self.precision == '64b' else jnp.int32,
            'bool': jnp.bool_,
        }
        self.default_do_grad = True

    def _setup(self):
        """
        Run any global setups for the jax lib.

        Currently a no-op: 64-bit support is switched on when this module is
        imported (``config.update('jax_enable_x64', True)``).
        """

    def _get_dtype(self, dtype):
        """
        Look up the JAX dtype for a backend dtype name.

        Args:
            dtype (:obj:`str`): A key of :attr:`dtypemap` (``'float'``, ``'int'``
              or ``'bool'``).

        Returns:
            The corresponding JAX dtype.

        Raises:
            KeyError: If ``dtype`` is not a key of :attr:`dtypemap`; the error is
              logged before being re-raised.
        """
        try:
            return self.dtypemap[dtype]
        except KeyError:
            log.error(
                f"Invalid dtype: dtype must be one of {list(self.dtypemap)}.",
                exc_info=True,
            )
            raise

    def clip(self, tensor_in, min_value, max_value):
        """
        Clips (limits) the tensor values to be within a specified min and max.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
            >>> pyhf.tensorlib.clip(a, -1, 1)
            Array([-1., -1.,  0.,  1.,  1.], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The input tensor object
            min_value (:obj:`scalar` or :obj:`tensor` or :obj:`None`): The minimum value to be clipped to
            max_value (:obj:`scalar` or :obj:`tensor` or :obj:`None`): The maximum value to be clipped to

        Returns:
            JAX ndarray: A clipped `tensor`
        """
        return jnp.clip(tensor_in, min_value, max_value)

    def erf(self, tensor_in):
        """
        The error function, evaluated element-wise.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
            >>> pyhf.tensorlib.erf(a)
            Array([-0.99532227, -0.84270079,  0.        ,  0.84270079,  0.99532227], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The input tensor object

        Returns:
            JAX ndarray: The values of the error function at the given points.
        """
        return special.erf(tensor_in)

    def erfinv(self, tensor_in):
        """
        The inverse of the error function, evaluated element-wise.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
            >>> pyhf.tensorlib.erfinv(pyhf.tensorlib.erf(a))
            Array([-2., -1.,  0.,  1.,  2.], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The input tensor object

        Returns:
            JAX ndarray: The values of the inverse of the error function at the given points.
        """
        return special.erfinv(tensor_in)

    def tile(self, tensor_in, repeats):
        """
        Repeat tensor data along a specific dimension

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([[1.0], [2.0]])
            >>> pyhf.tensorlib.tile(a, (1, 2))
            Array([[1., 1.],
                   [2., 2.]], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The tensor to be repeated
            repeats (:obj:`tensor`): The tuple of multipliers for each dimension

        Returns:
            JAX ndarray: The tensor with repeated axes
        """
        return jnp.tile(tensor_in, repeats)

    def conditional(self, predicate, true_callable, false_callable):
        """
        Runs a callable conditional on the boolean value of the evaluation of a predicate

        The predicate is tested with Python truthiness, so it must be a concrete
        value, and only the selected callable is invoked.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensorlib = pyhf.tensorlib
            >>> a = tensorlib.astensor([4])
            >>> b = tensorlib.astensor([5])
            >>> tensorlib.conditional((a < b)[0], lambda: a + b, lambda: a - b)
            Array([9.], dtype=float64)

        Args:
            predicate (:obj:`scalar`): The logical condition that determines which callable to evaluate
            true_callable (:obj:`callable`): The callable that is evaluated when the :code:`predicate` evaluates to :code:`true`
            false_callable (:obj:`callable`): The callable that is evaluated when the :code:`predicate` evaluates to :code:`false`

        Returns:
            JAX ndarray: The output of the callable that was evaluated
        """
        return true_callable() if predicate else false_callable()

    def tolist(self, tensor_in):
        """
        Convert a tensor (or nested sequence) to a nested Python list.

        If JAX cannot convert the input (``TypeError``/``ValueError``) and the
        input is already a ``list``, it is returned unchanged; otherwise the
        error is re-raised.
        """
        try:
            return jnp.asarray(tensor_in).tolist()
        except (TypeError, ValueError):
            if isinstance(tensor_in, list):
                return tensor_in
            raise

    def outer(self, tensor_in_1, tensor_in_2):
        """Outer product of two tensors (``jnp.outer``)."""
        return jnp.outer(tensor_in_1, tensor_in_2)

    def gather(self, tensor, indices):
        """Select entries of ``tensor`` by indexing it with ``indices``."""
        return tensor[indices]

    def boolean_mask(self, tensor, mask):
        """Select the entries of ``tensor`` where the boolean ``mask`` is true."""
        return tensor[mask]

    def isfinite(self, tensor):
        """Element-wise test for finiteness (not infinite and not NaN)."""
        return jnp.isfinite(tensor)

    def astensor(self, tensor_in, dtype="float"):
        """
        Convert to a JAX ndarray.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> tensor
            Array([[1., 2., 3.],
                   [4., 5., 6.]], dtype=float64)
            >>> type(tensor) # doctest:+ELLIPSIS
            <class '...ArrayImpl'>

        Args:
            tensor_in (Number or Tensor): Tensor object
            dtype (:obj:`str`): One of the keys of :attr:`dtypemap`
              (``'float'``, ``'int'`` or ``'bool'``)

        Returns:
            `jaxlib.xla_extension.ArrayImpl`: A multi-dimensional, fixed-size homogeneous array.

        Raises:
            KeyError: If ``dtype`` is not a supported dtype name.
        """
        # TODO: Remove doctest:+ELLIPSIS when JAX API stabilized
        return jnp.asarray(tensor_in, dtype=self._get_dtype(dtype))

    def sum(self, tensor_in, axis=None):
        """Sum of the tensor elements, over all elements or along ``axis``."""
        return jnp.sum(tensor_in, axis=axis)

    def product(self, tensor_in, axis=None):
        """Product of the tensor elements, over all elements or along ``axis``."""
        return jnp.prod(tensor_in, axis=axis)

    def abs(self, tensor):
        """Element-wise absolute value."""
        return jnp.abs(tensor)

    def ones(self, shape, dtype="float"):
        """
        Create a tensor of ones.

        Args:
            shape (:obj:`tuple`): The shape of the output tensor
            dtype (:obj:`str`): One of the keys of :attr:`dtypemap`

        Raises:
            KeyError: If ``dtype`` is not a supported dtype name.
        """
        return jnp.ones(shape, dtype=self._get_dtype(dtype))

    def zeros(self, shape, dtype="float"):
        """
        Create a tensor of zeros.

        Args:
            shape (:obj:`tuple`): The shape of the output tensor
            dtype (:obj:`str`): One of the keys of :attr:`dtypemap`

        Raises:
            KeyError: If ``dtype`` is not a supported dtype name.
        """
        return jnp.zeros(shape, dtype=self._get_dtype(dtype))

    def power(self, tensor_in_1, tensor_in_2):
        """Element-wise ``tensor_in_1 ** tensor_in_2``."""
        return jnp.power(tensor_in_1, tensor_in_2)

    def sqrt(self, tensor_in):
        """Element-wise square root."""
        return jnp.sqrt(tensor_in)

    def divide(self, tensor_in_1, tensor_in_2):
        """Element-wise ``tensor_in_1 / tensor_in_2``."""
        return jnp.divide(tensor_in_1, tensor_in_2)

    def log(self, tensor_in):
        """Element-wise natural logarithm."""
        return jnp.log(tensor_in)

    def exp(self, tensor_in):
        """Element-wise exponential."""
        return jnp.exp(tensor_in)

    def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
        r"""
        Compute the :math:`q`-th percentile of the tensor along the specified axis.

        Example:

            >>> import pyhf
            >>> import jax.numpy as jnp
            >>> pyhf.set_backend("jax")
            >>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
            >>> pyhf.tensorlib.percentile(a, 50)
            Array(3.5, dtype=float64)
            >>> pyhf.tensorlib.percentile(a, 50, axis=1)
            Array([7., 2.], dtype=float64)

        Args:
            tensor_in (`tensor`): The tensor containing the data
            q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute
            axis (`number` or `tensor`): The dimensions along which to compute
            interpolation (:obj:`str`): The interpolation method to use when the
             desired percentile lies between two data points ``i < j``:

                - ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the
                  fractional part of the index surrounded by ``i`` and ``j``.

                - ``'lower'``: ``i``.

                - ``'higher'``: ``j``.

                - ``'midpoint'``: ``(i + j) / 2``.

                - ``'nearest'``: ``i`` or ``j``, whichever is nearest.

        Returns:
            JAX ndarray: The value of the :math:`q`-th percentile of the tensor along the specified axis.

        .. versionadded:: 0.7.0
        """
        # JAX names this option ``method``; ``interpolation`` is kept as the
        # backend-facing parameter name.
        return jnp.percentile(tensor_in, q, axis=axis, method=interpolation)

    def stack(self, sequence, axis=0):
        """Join a sequence of tensors along a new axis."""
        return jnp.stack(sequence, axis=axis)

    def where(self, mask, tensor_in_1, tensor_in_2):
        """Element-wise select from ``tensor_in_1`` where ``mask`` is true, else ``tensor_in_2``."""
        return jnp.where(mask, tensor_in_1, tensor_in_2)

    def concatenate(self, sequence, axis=0):
        """
        Join a sequence of arrays along an existing axis.

        Args:
            sequence: sequence of tensors
            axis: dimension along which to concatenate

        Returns:
            output: the concatenated tensor
        """
        return jnp.concatenate(sequence, axis=axis)

    def simple_broadcast(self, *args):
        """
        Broadcast a sequence of 1 dimensional arrays.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.simple_broadcast(
            ...   pyhf.tensorlib.astensor([1]),
            ...   pyhf.tensorlib.astensor([2, 3, 4]),
            ...   pyhf.tensorlib.astensor([5, 6, 7]))
            [Array([1., 1., 1.], dtype=float64), Array([2., 3., 4.], dtype=float64), Array([5., 6., 7.], dtype=float64)]

        Args:
            args (Array of Tensors): Sequence of arrays

        Returns:
            list of Tensors: The sequence broadcast together.
        """
        return jnp.broadcast_arrays(*args)

    def shape(self, tensor):
        """Return the ``.shape`` attribute of ``tensor``."""
        return tensor.shape

    def reshape(self, tensor, newshape):
        """Return ``tensor`` with its data arranged in shape ``newshape``."""
        return jnp.reshape(tensor, newshape)

    def ravel(self, tensor):
        """
        Flatten the tensor into a 1-dimensional array.

        JAX arrays are immutable, so the result behaves as an independent
        array regardless of whether JAX shares the underlying buffer.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> pyhf.tensorlib.ravel(tensor)
            Array([1., 2., 3., 4., 5., 6.], dtype=float64)

        Args:
            tensor (Tensor): Tensor object

        Returns:
            JAX ndarray: A flattened array.
        """
        return jnp.ravel(tensor)

    def einsum(self, subscripts, *operands):
        """
        Evaluates the Einstein summation convention on the operands.

        Using the Einstein summation convention, many common multi-dimensional
        array operations (traces, transposes, matrix products, contractions)
        can be expressed compactly as a subscript string, e.g. ``'ii'`` for
        the trace of a square matrix or ``'ij,jk->ik'`` for a matrix product.

        Args:
            subscripts: str, specifies the subscripts for summation
            operands: list of array_like, these are the tensors for the operation

        Returns:
            tensor: the calculation based on the Einstein summation convention
        """
        return jnp.einsum(subscripts, *operands)

    def poisson_logpdf(self, n, lam):
        r"""
        Log of the continuous approximation to the Poisson p.m.f.,
        :math:`n \log\lambda - \lambda - \ln\Gamma(n+1)`.

        ``xlogy`` is used so that ``n = 0, lam = 0`` contributes ``0`` rather
        than ``nan``.

        Args:
            n (:obj:`tensor` or :obj:`float`): The observed number of events
            lam (:obj:`tensor` or :obj:`float`): The expected number of events

        Returns:
            JAX ndarray: Value of :math:`\log\mathrm{Pois}(n|\lambda)`
        """
        n = jnp.asarray(n)
        lam = jnp.asarray(lam)
        return xlogy(n, lam) - lam - gammaln(n + 1.0)

    def poisson(self, n, lam):
        r"""
        The continuous approximation, using :math:`n! = \Gamma\left(n+1\right)`,
        to the probability mass function of the Poisson distribution evaluated
        at :code:`n` given the parameter :code:`lam`.

        .. note::

            Though the p.m.f of the Poisson distribution is not defined for
            :math:`\lambda = 0`, the limit as :math:`\lambda \to 0` is still
            defined, which gives a degenerate p.m.f. of

            .. math::

                \lim_{\lambda \to 0} \,\mathrm{Pois}(n | \lambda) =
                \left\{\begin{array}{ll}
                1, & n = 0,\\
                0, & n > 0
                \end{array}\right.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.poisson(5., 6.)
            Array(0.16062314, dtype=float64, weak_type=True)
            >>> values = pyhf.tensorlib.astensor([5., 9.])
            >>> rates = pyhf.tensorlib.astensor([6., 8.])
            >>> pyhf.tensorlib.poisson(values, rates)
            Array([0.16062314, 0.12407692], dtype=float64)

        Args:
            n (:obj:`tensor` or :obj:`float`): The value at which to evaluate the approximation to the Poisson distribution p.m.f.
                                  (the observed number of events)
            lam (:obj:`tensor` or :obj:`float`): The mean of the Poisson distribution p.m.f.
                                    (the expected number of events)

        Returns:
            JAX ndarray: Value of the continuous approximation to Poisson(n|lam)
        """
        n = jnp.asarray(n)
        lam = jnp.asarray(lam)
        return jnp.exp(xlogy(n, lam) - lam - gammaln(n + 1.0))

    def normal_logpdf(self, x, mu, sigma):
        r"""
        Log of the Normal p.d.f.,
        :math:`-\log(\sigma\sqrt{2\pi}) - \left(\frac{x-\mu}{\sqrt{2}\sigma}\right)^2`.

        Args:
            x (:obj:`tensor` or :obj:`float`): The value at which to evaluate the p.d.f.
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            JAX ndarray: Value of :math:`\log\mathrm{Normal}(x|\mu, \sigma)`
        """
        # this is much faster than
        # norm.logpdf(x, loc=mu, scale=sigma)
        # https://codereview.stackexchange.com/questions/69718/fastest-computation-of-n-likelihoods-on-normal-distributions
        root2 = jnp.sqrt(2)
        root2pi = jnp.sqrt(2 * jnp.pi)
        prefactor = -jnp.log(sigma * root2pi)
        summand = -jnp.square(jnp.divide((x - mu), (root2 * sigma)))
        return prefactor + summand

    def normal(self, x, mu, sigma):
        r"""
        The probability density function of the Normal distribution evaluated
        at :code:`x` given parameters of mean of :code:`mu` and standard deviation
        of :code:`sigma`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.normal(0.5, 0., 1.)
            Array(0.35206533, dtype=float64, weak_type=True)
            >>> values = pyhf.tensorlib.astensor([0.5, 2.0])
            >>> means = pyhf.tensorlib.astensor([0., 2.3])
            >>> sigmas = pyhf.tensorlib.astensor([1., 0.8])
            >>> pyhf.tensorlib.normal(values, means, sigmas)
            Array([0.35206533, 0.46481887], dtype=float64)

        Args:
            x (:obj:`tensor` or :obj:`float`): The value at which to evaluate the Normal distribution p.d.f.
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            JAX ndarray: Value of Normal(x|mu, sigma)
        """
        return norm.pdf(x, loc=mu, scale=sigma)

    def normal_cdf(self, x, mu=0, sigma=1):
        """
        The cumulative distribution function for the Normal distribution

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> pyhf.tensorlib.normal_cdf(0.8)
            Array(0.7881446, dtype=float64)
            >>> values = pyhf.tensorlib.astensor([0.8, 2.0])
            >>> pyhf.tensorlib.normal_cdf(values)
            Array([0.7881446 , 0.97724987], dtype=float64)

        Args:
            x (:obj:`tensor` or :obj:`float`): The observed value of the random variable to evaluate the CDF for
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            JAX ndarray: The CDF
        """
        return norm.cdf(x, loc=mu, scale=sigma)

    def poisson_dist(self, rate):
        r"""
        The Poisson distribution with rate parameter :code:`rate`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> rates = pyhf.tensorlib.astensor([5, 8])
            >>> values = pyhf.tensorlib.astensor([4, 9])
            >>> poissons = pyhf.tensorlib.poisson_dist(rates)
            >>> poissons.log_prob(values)
            Array([-1.74030218, -2.0868536 ], dtype=float64)

        Args:
            rate (:obj:`tensor` or :obj:`float`): The mean of the Poisson distribution (the expected number of events)

        Returns:
            Poisson distribution: The Poisson distribution class
        """
        return _BasicPoisson(rate)

    def normal_dist(self, mu, sigma):
        r"""
        The Normal distribution with mean :code:`mu` and standard deviation :code:`sigma`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> means = pyhf.tensorlib.astensor([5, 8])
            >>> stds = pyhf.tensorlib.astensor([1, 0.5])
            >>> values = pyhf.tensorlib.astensor([4, 9])
            >>> normals = pyhf.tensorlib.normal_dist(means, stds)
            >>> normals.log_prob(values)
            Array([-1.41893853, -2.22579135], dtype=float64)

        Args:
            mu (:obj:`tensor` or :obj:`float`): The mean of the Normal distribution
            sigma (:obj:`tensor` or :obj:`float`): The standard deviation of the Normal distribution

        Returns:
            Normal distribution: The Normal distribution class
        """
        return _BasicNormal(mu, sigma)

    def to_numpy(self, tensor_in):
        """
        Convert the JAX tensor to a :class:`numpy.ndarray`.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> tensor
            Array([[1., 2., 3.],
                   [4., 5., 6.]], dtype=float64)
            >>> numpy_ndarray = pyhf.tensorlib.to_numpy(tensor)
            >>> numpy_ndarray
            array([[1., 2., 3.],
                   [4., 5., 6.]])
            >>> type(numpy_ndarray)
            <class 'numpy.ndarray'>

        Args:
            tensor_in (:obj:`tensor`): The input tensor object.

        Returns:
            :class:`numpy.ndarray`: The tensor converted to a NumPy ``ndarray``.
        """
        return np.asarray(tensor_in, dtype=tensor_in.dtype)

    def transpose(self, tensor_in):
        """
        Transpose the tensor.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> tensor
            Array([[1., 2., 3.],
                   [4., 5., 6.]], dtype=float64)
            >>> pyhf.tensorlib.transpose(tensor)
            Array([[1., 4.],
                   [2., 5.],
                   [3., 6.]], dtype=float64)

        Args:
            tensor_in (:obj:`tensor`): The input tensor object.

        Returns:
            JAX ndarray: The transpose of the input tensor.

        .. versionadded:: 0.7.0
        """
        return tensor_in.transpose()