jax_backend#

class pyhf.tensor.jax_backend.jax_backend(**kwargs)[source]#

Bases: object

JAX backend for pyhf

__init__(**kwargs)[source]#

Attributes

name#
precision#
dtypemap#
default_do_grad#
array_subtype#

The array content type for jax

array_type#

The array type for jax

Methods

_setup()[source]#

Run any global setups for the jax lib.

abs(tensor)[source]#
astensor(tensor_in, dtype='float')[source]#

Convert to a JAX ndarray.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
Array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float64)
>>> type(tensor)  # doctest: +ELLIPSIS
<class '...ArrayImpl'>
Parameters:

tensor_in (Number or Tensor) – Tensor object

Returns:

A multi-dimensional, fixed-size homogeneous array.

Return type:

jaxlib.xla_extension.ArrayImpl

boolean_mask(tensor, mask)[source]#
clip(tensor_in, min_value, max_value)[source]#

Clips (limits) the tensor values to be within a specified min and max.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
>>> pyhf.tensorlib.clip(a, -1, 1)
Array([-1., -1.,  0.,  1.,  1.], dtype=float64)
Parameters:
  • tensor_in (tensor) – The input tensor object

  • min_value (scalar or tensor or None) – The minimum value to be clipped to

  • max_value (scalar or tensor or None) – The maximum value to be clipped to

Returns:

A clipped tensor

Return type:

JAX ndarray

concatenate(sequence, axis=0)[source]#

Join a sequence of arrays along an existing axis.

Parameters:
  • sequence – sequence of tensors

  • axis – dimension along which to concatenate

Returns:

the concatenated tensor

Return type:

JAX ndarray

conditional(predicate, true_callable, false_callable)[source]#

Runs a callable conditional on the boolean value of the evaluation of a predicate

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensorlib = pyhf.tensorlib
>>> a = tensorlib.astensor([4])
>>> b = tensorlib.astensor([5])
>>> tensorlib.conditional((a < b)[0], lambda: a + b, lambda: a - b)
Array([9.], dtype=float64)
Parameters:
  • predicate (scalar) – The logical condition that determines which callable to evaluate

  • true_callable (callable) – The callable that is evaluated when the predicate evaluates to true

  • false_callable (callable) – The callable that is evaluated when the predicate evaluates to false

Returns:

The output of the callable that was evaluated

Return type:

JAX ndarray

divide(tensor_in_1, tensor_in_2)[source]#
einsum(subscripts, *operands)[source]#

Evaluates the Einstein summation convention on the operands.

Using the Einstein summation convention, many common multi-dimensional array operations can be represented in a simple fashion. This function provides a way to compute such summations. See the NumPy einsum documentation for examples showing how many common NumPy functions can be implemented as calls to einsum.

Parameters:
  • subscripts – str, specifies the subscripts for summation

  • operands – list of array_like, these are the tensors for the operation

Returns:

the calculation based on the Einstein summation convention

Return type:

tensor

erf(tensor_in)[source]#

The error function, evaluated elementwise on a real-valued tensor.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
>>> pyhf.tensorlib.erf(a)
Array([-0.99532227, -0.84270079,  0.        ,  0.84270079,  0.99532227],
      dtype=float64)
Parameters:

tensor_in (tensor) – The input tensor object

Returns:

The values of the error function at the given points.

Return type:

JAX ndarray

erfinv(tensor_in)[source]#

The inverse of the error function, evaluated elementwise on a real-valued tensor.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
>>> pyhf.tensorlib.erfinv(pyhf.tensorlib.erf(a))
Array([-2., -1.,  0.,  1.,  2.], dtype=float64)
Parameters:

tensor_in (tensor) – The input tensor object

Returns:

The values of the inverse of the error function at the given points.

Return type:

JAX ndarray

exp(tensor_in)[source]#
gather(tensor, indices)[source]#
isfinite(tensor)[source]#
log(tensor_in)[source]#
normal(x, mu, sigma)[source]#

The probability density function of the Normal distribution evaluated at x, given mean mu and standard deviation sigma.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.normal(0.5, 0., 1.)
Array(0.35206533, dtype=float64, weak_type=True)
>>> values = pyhf.tensorlib.astensor([0.5, 2.0])
>>> means = pyhf.tensorlib.astensor([0., 2.3])
>>> sigmas = pyhf.tensorlib.astensor([1., 0.8])
>>> pyhf.tensorlib.normal(values, means, sigmas)
Array([0.35206533, 0.46481887], dtype=float64)
Parameters:
  • x (tensor or float) – The value at which to evaluate the Normal distribution p.d.f.

  • mu (tensor or float) – The mean of the Normal distribution

  • sigma (tensor or float) – The standard deviation of the Normal distribution

Returns:

Value of Normal(x|mu, sigma)

Return type:

JAX ndarray

normal_cdf(x, mu=0, sigma=1)[source]#

The cumulative distribution function for the Normal distribution.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.normal_cdf(0.8)
Array(0.7881446, dtype=float64)
>>> values = pyhf.tensorlib.astensor([0.8, 2.0])
>>> pyhf.tensorlib.normal_cdf(values)
Array([0.7881446 , 0.97724987], dtype=float64)
Parameters:
  • x (tensor or float) – The observed value of the random variable to evaluate the CDF for

  • mu (tensor or float) – The mean of the Normal distribution

  • sigma (tensor or float) – The standard deviation of the Normal distribution

Returns:

The CDF

Return type:

JAX ndarray

normal_dist(mu, sigma)[source]#

The Normal distribution with mean mu and standard deviation sigma.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> means = pyhf.tensorlib.astensor([5, 8])
>>> stds = pyhf.tensorlib.astensor([1, 0.5])
>>> values = pyhf.tensorlib.astensor([4, 9])
>>> normals = pyhf.tensorlib.normal_dist(means, stds)
>>> normals.log_prob(values)
Array([-1.41893853, -2.22579135], dtype=float64)
Parameters:
  • mu (tensor or float) – The mean of the Normal distribution

  • sigma (tensor or float) – The standard deviation of the Normal distribution

Returns:

The Normal distribution class

Return type:

Normal distribution

normal_logpdf(x, mu, sigma)[source]#
ones(shape, dtype='float')[source]#
outer(tensor_in_1, tensor_in_2)[source]#
percentile(tensor_in, q, axis=None, interpolation='linear')[source]#

Compute the \(q\)-th percentile of the tensor along the specified axis.

Example

>>> import pyhf
>>> import jax.numpy as jnp
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> pyhf.tensorlib.percentile(a, 50)
Array(3.5, dtype=float64)
>>> pyhf.tensorlib.percentile(a, 50, axis=1)
Array([7., 2.], dtype=float64)
Parameters:
  • tensor_in (tensor) – The tensor containing the data

  • q (float or tensor) – The \(q\)-th percentile to compute

  • axis (number or tensor) – The dimensions along which to compute

  • interpolation (str) –

    The interpolation method to use when the desired percentile lies between two data points i < j:

    • 'linear': i + (j - i) * fraction, where fraction is the fractional part of the index surrounded by i and j.

    • 'lower': i.

    • 'higher': j.

    • 'midpoint': (i + j) / 2.

    • 'nearest': i or j, whichever is nearest.

Returns:

The value of the \(q\)-th percentile of the tensor along the specified axis.

Return type:

JAX ndarray

Added in version 0.7.0.

poisson(n, lam)[source]#

The continuous approximation, using \(n! = \Gamma\left(n+1\right)\), to the probability mass function of the Poisson distribution evaluated at n given the parameter lam.

Note

Though the p.m.f. of the Poisson distribution is not defined for \(\lambda = 0\), the limit as \(\lambda \to 0\) is still defined, which gives a degenerate p.m.f. of

\[\begin{split}\lim_{\lambda \to 0} \,\mathrm{Pois}(n | \lambda) = \left\{\begin{array}{ll} 1, & n = 0,\\ 0, & n > 0 \end{array}\right.\end{split}\]

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.poisson(5., 6.)
Array(0.16062314, dtype=float64, weak_type=True)
>>> values = pyhf.tensorlib.astensor([5., 9.])
>>> rates = pyhf.tensorlib.astensor([6., 8.])
>>> pyhf.tensorlib.poisson(values, rates)
Array([0.16062314, 0.12407692], dtype=float64)
Parameters:
  • n (tensor or float) – The value at which to evaluate the approximation to the Poisson distribution p.m.f. (the observed number of events)

  • lam (tensor or float) – The mean of the Poisson distribution p.m.f. (the expected number of events)

Returns:

Value of the continuous approximation to Poisson(n|lam)

Return type:

JAX ndarray

poisson_dist(rate)[source]#

The Poisson distribution with rate parameter rate.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> rates = pyhf.tensorlib.astensor([5, 8])
>>> values = pyhf.tensorlib.astensor([4, 9])
>>> poissons = pyhf.tensorlib.poisson_dist(rates)
>>> poissons.log_prob(values)
Array([-1.74030218, -2.0868536 ], dtype=float64)
Parameters:

rate (tensor or float) – The mean of the Poisson distribution (the expected number of events)

Returns:

The Poisson distribution class

Return type:

Poisson distribution

poisson_logpdf(n, lam)[source]#
power(tensor_in_1, tensor_in_2)[source]#
product(tensor_in, axis=None)[source]#
ravel(tensor)[source]#

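Return a flattened (one-dimensional) version of the tensor. Because JAX arrays are immutable, this is a new array rather than a view as in NumPy; under JIT compilation the copy can often be optimized away.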

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> pyhf.tensorlib.ravel(tensor)
Array([1., 2., 3., 4., 5., 6.], dtype=float64)
Parameters:

tensor (Tensor) – Tensor object

Returns:

A flattened array.

Return type:

JAX ndarray

reshape(tensor, newshape)[source]#
shape(tensor)[source]#
simple_broadcast(*args)[source]#

Broadcast a sequence of one-dimensional arrays.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.simple_broadcast(
...   pyhf.tensorlib.astensor([1]),
...   pyhf.tensorlib.astensor([2, 3, 4]),
...   pyhf.tensorlib.astensor([5, 6, 7]))
[Array([1., 1., 1.], dtype=float64), Array([2., 3., 4.], dtype=float64), Array([5., 6., 7.], dtype=float64)]
Parameters:

args (Array of Tensors) – Sequence of arrays

Returns:

The sequence broadcast together.

Return type:

list of Tensors

sqrt(tensor_in)[source]#
stack(sequence, axis=0)[source]#
sum(tensor_in, axis=None)[source]#
tile(tensor_in, repeats)[source]#

Repeat tensor data along each dimension.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([[1.0], [2.0]])
>>> pyhf.tensorlib.tile(a, (1, 2))
Array([[1., 1.],
       [2., 2.]], dtype=float64)
Parameters:
  • tensor_in (tensor) – The tensor to be repeated

  • repeats (tuple or tensor) – The number of repetitions along each dimension

Returns:

The tensor with repeated axes

Return type:

JAX ndarray

to_numpy(tensor_in)[source]#

Convert the JAX tensor to a numpy.ndarray.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
Array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float64)
>>> numpy_ndarray = pyhf.tensorlib.to_numpy(tensor)
>>> numpy_ndarray
array([[1., 2., 3.],
       [4., 5., 6.]])
>>> type(numpy_ndarray)
<class 'numpy.ndarray'>
Parameters:

tensor_in (tensor) – The input tensor object.

Returns:

The tensor converted to a NumPy ndarray.

Return type:

numpy.ndarray

tolist(tensor_in)[source]#
transpose(tensor_in)[source]#

Transpose the tensor.

Example

>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
Array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float64)
>>> pyhf.tensorlib.transpose(tensor)
Array([[1., 4.],
       [2., 5.],
       [3., 6.]], dtype=float64)
Parameters:

tensor_in (tensor) – The input tensor object.

Returns:

The transpose of the input tensor.

Return type:

JAX ndarray

Added in version 0.7.0.

where(mask, tensor_in_1, tensor_in_2)[source]#
zeros(shape, dtype='float')[source]#