jax_backend
- class pyhf.tensor.jax_backend.jax_backend(**kwargs)
Bases: object
JAX backend for pyhf
Attributes
- name
- precision
- dtypemap
- default_do_grad
- array_subtype
The array content type for JAX
- array_type
The array type for JAX
Methods
- astensor(tensor_in, dtype='float')
Convert to a JAX ndarray.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
Array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float64)
>>> type(tensor)
<class '...ArrayImpl'>
- Parameters:
tensor_in (Number or Tensor) – Tensor object
dtype (str) – The name of the target data type, looked up in dtypemap (for example 'float')
- Returns:
A multi-dimensional, fixed-size homogeneous array.
- Return type:
jaxlib.xla_extension.ArrayImpl
- clip(tensor_in, min_value, max_value)
Clips (limits) the tensor values to be within a specified min and max.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2, -1, 0, 1, 2])
>>> pyhf.tensorlib.clip(a, -1, 1)
Array([-1., -1., 0., 1., 1.], dtype=float64)
- concatenate(sequence, axis=0)
Join a sequence of arrays along an existing axis.
- Parameters:
sequence (sequence of tensors) – The tensors to join
axis (int) – The dimension along which to concatenate
- Returns:
The concatenated tensor
- Return type:
JAX ndarray
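A minimal sketch of how axis selects the dimension that grows, assuming the JAX backend forwards to jax.numpy.concatenate, which follows NumPy's signature. NumPy stands in for JAX here so the snippet runs without JAX installed; the arrays are illustrative.

```python
import numpy as np

a = np.array([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)
b = np.array([[5.0, 6.0]])              # shape (1, 2)
c = np.array([[7.0], [8.0]])            # shape (2, 1)

# axis=0 appends rows: every dimension except axis 0 must match.
rows = np.concatenate([a, b], axis=0)
print(rows.shape)  # (3, 2)

# axis=1 appends columns: the row counts must match.
cols = np.concatenate([a, c], axis=1)
print(cols.tolist())  # [[1.0, 2.0, 7.0], [3.0, 4.0, 8.0]]
```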
- conditional(predicate, true_callable, false_callable)
Runs a callable conditional on the boolean value of the evaluation of a predicate.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensorlib = pyhf.tensorlib
>>> a = tensorlib.astensor([4])
>>> b = tensorlib.astensor([5])
>>> tensorlib.conditional((a < b)[0], lambda: a + b, lambda: a - b)
Array([9.], dtype=float64)
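- Parameters:
predicate (scalar) – The logical condition that determines which callable to evaluate
true_callable (callable) – The callable that is evaluated when predicate evaluates to true
false_callable (callable) – The callable that is evaluated when predicate evaluates to false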
- Returns:
The output of the callable that was evaluated
- Return type:
JAX ndarray
- einsum(subscripts, *operands)
Evaluates the Einstein summation convention on the operands.
Using the Einstein summation convention, many common multi-dimensional array operations can be represented in a compact form: when the output subscripts are given explicitly after '->', every index that appears in the inputs but not in the output is summed over. Many common NumPy operations, such as matrix products, transposes, traces, and sums along an axis, can be written as calls to einsum.
- Parameters:
subscripts (str) – Specifies the subscripts for summation
operands (list of array_like) – The tensors for the operation
- Returns:
the calculation based on the Einstein summation convention
- Return type:
tensor
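A few common NumPy operations written as einsum calls, as a sketch. It uses numpy.einsum, assuming the backend's einsum follows the same subscript convention as jax.numpy.einsum and NumPy; each line checks the einsum form against the direct NumPy operation.

```python
import numpy as np

A = np.arange(6.0).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
B = np.arange(12.0).reshape(3, 4)
v = np.array([1.0, 2.0, 3.0])
S = np.arange(9.0).reshape(3, 3)

# Matrix product: j appears in both inputs but not the output, so it is summed.
assert np.allclose(np.einsum("ij,jk->ik", A, B), A @ B)
# Matrix-vector product.
assert np.allclose(np.einsum("ij,j->i", A, v), A @ v)
# Transpose: indices are only reordered, nothing is summed.
assert np.allclose(np.einsum("ij->ji", A), A.T)
# Sum of all elements: an empty output sums over every index.
assert np.isclose(np.einsum("ij->", A), A.sum())
# Trace: a repeated index on one operand selects the diagonal.
assert np.isclose(np.einsum("ii->", S), np.trace(S))
# Column sums: keep j, sum over i.
assert np.allclose(np.einsum("ij->j", A), A.sum(axis=0))
```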
- erf(tensor_in)
The error function of complex argument.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
>>> pyhf.tensorlib.erf(a)
Array([-0.99532227, -0.84270079, 0. , 0.84270079, 0.99532227], dtype=float64)
- Parameters:
tensor_in (tensor) – The input tensor object
- Returns:
The values of the error function at the given points.
- Return type:
JAX ndarray
- erfinv(tensor_in)
The inverse of the error function of complex argument.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([-2., -1., 0., 1., 2.])
>>> pyhf.tensorlib.erfinv(pyhf.tensorlib.erf(a))
Array([-2., -1., 0., 1., 2.], dtype=float64)
- Parameters:
tensor_in (tensor) – The input tensor object
- Returns:
The values of the inverse of the error function at the given points.
- Return type:
JAX ndarray
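The round trip in the example works because erfinv inverts erf on the open interval (-1, 1). The standard-library sketch below does not use JAX; it expresses that inverse through the standard Normal quantile function via the identity erf(z) = 2Φ(z√2) - 1, so erfinv(y) = Φ⁻¹((y + 1) / 2) / √2. It illustrates the relationship and is not the backend's implementation; erfinv here is a hypothetical local helper.

```python
import math
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # mean 0, standard deviation 1


def erfinv(y: float) -> float:
    """Inverse error function for -1 < y < 1 via the standard Normal quantile."""
    if not -1.0 < y < 1.0:
        raise ValueError("erfinv is only finite on the open interval (-1, 1)")
    # erf(z) = 2 * Phi(z * sqrt(2)) - 1  =>  z = Phi^{-1}((y + 1) / 2) / sqrt(2)
    return _STD_NORMAL.inv_cdf((y + 1.0) / 2.0) / math.sqrt(2.0)


# Same round trip as the example: erfinv(erf(x)) recovers x.
for x in [-2.0, -1.0, 0.0, 1.0, 2.0]:
    print(x, erfinv(math.erf(x)))
```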
- normal(x, mu, sigma)
The probability density function of the Normal distribution with mean mu and standard deviation sigma, evaluated at x.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.normal(0.5, 0., 1.)
Array(0.35206533, dtype=float64, weak_type=True)
>>> values = pyhf.tensorlib.astensor([0.5, 2.0])
>>> means = pyhf.tensorlib.astensor([0., 2.3])
>>> sigmas = pyhf.tensorlib.astensor([1., 0.8])
>>> pyhf.tensorlib.normal(values, means, sigmas)
Array([0.35206533, 0.46481887], dtype=float64)
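The density being evaluated is f(x) = exp(-(x - mu)² / (2 sigma²)) / (sigma √(2π)). A scalar, standard-library sketch (not the backend's vectorised code; normal_pdf is a hypothetical helper) that reproduces the values in the example:

```python
import math


def normal_pdf(x: float, mu: float, sigma: float) -> float:
    """Normal density exp(-(x - mu)^2 / (2 sigma^2)) / (sigma * sqrt(2 pi))."""
    z = (x - mu) / sigma  # standardised distance from the mean
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


print(normal_pdf(0.5, 0.0, 1.0))  # ≈ 0.35206533
print(normal_pdf(2.0, 2.3, 0.8))  # ≈ 0.46481887
```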
- normal_cdf(x, mu=0, sigma=1)
The cumulative distribution function for the Normal distribution with mean mu and standard deviation sigma.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.normal_cdf(0.8)
Array(0.7881446, dtype=float64)
>>> values = pyhf.tensorlib.astensor([0.8, 2.0])
>>> pyhf.tensorlib.normal_cdf(values)
Array([0.7881446 , 0.97724987], dtype=float64)
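For reference, the Normal CDF can be written with the error function as Φ(x) = ½[1 + erf((x - mu) / (sigma √2))]. A scalar standard-library sketch with the same default arguments as the signature above; it is an illustration under that formula, not the backend's implementation:

```python
import math


def normal_cdf(x: float, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Normal CDF expressed through the error function."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))


print(normal_cdf(0.8))  # ≈ 0.7881446
print(normal_cdf(2.0))  # ≈ 0.97724987
print(normal_cdf(5.0, mu=5.0, sigma=2.0))  # 0.5: half the mass lies below the mean
```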
- normal_dist(mu, sigma)
The Normal distribution with mean mu and standard deviation sigma.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> means = pyhf.tensorlib.astensor([5, 8])
>>> stds = pyhf.tensorlib.astensor([1, 0.5])
>>> values = pyhf.tensorlib.astensor([4, 9])
>>> normals = pyhf.tensorlib.normal_dist(means, stds)
>>> normals.log_prob(values)
Array([-1.41893853, -2.22579135], dtype=float64)
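The log_prob values in the example follow from the Normal log-density, log f(x) = -(x - mu)² / (2 sigma²) - log(sigma) - ½ log(2π). A standard-library sketch that recomputes them elementwise; normal_log_prob is a hypothetical helper, not part of pyhf:

```python
import math


def normal_log_prob(x: float, mu: float, sigma: float) -> float:
    """Log of the Normal density, computed directly in log space."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


means = [5.0, 8.0]
stds = [1.0, 0.5]
values = [4.0, 9.0]
log_probs = [normal_log_prob(x, m, s) for x, m, s in zip(values, means, stds)]
print(log_probs)  # ≈ [-1.41893853, -2.22579135]
```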
- percentile(tensor_in, q, axis=None, interpolation='linear')
Compute the \(q\)-th percentile of the tensor along the specified axis.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> pyhf.tensorlib.percentile(a, 50)
Array(3.5, dtype=float64)
>>> pyhf.tensorlib.percentile(a, 50, axis=1)
Array([7., 2.], dtype=float64)
- Parameters:
tensor_in (tensor) – The tensor containing the data
q (float or tensor) – The \(q\)-th percentile to compute
axis (number or tensor) – The dimensions along which to compute
interpolation (str) – The interpolation method to use when the desired percentile lies between two data points i < j:
  - 'linear': i + (j - i) * fraction, where fraction is the fractional part of the index surrounded by i and j.
  - 'lower': i.
  - 'higher': j.
  - 'midpoint': (i + j) / 2.
  - 'nearest': i or j, whichever is nearest.
- Returns:
The value of the \(q\)-th percentile of the tensor along the specified axis.
- Return type:
JAX ndarray
Added in version 0.7.0.
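The interpolation rules above can be sketched for a flattened input (the axis=None case) in pure Python. This assumes the fractional index is q/100 × (n - 1) into the sorted data, as in NumPy; percentile here is a hypothetical helper, and the half-to-even tie-breaking used for 'nearest' follows NumPy and is not confirmed for the JAX backend.

```python
import math


def percentile(data, q, interpolation="linear"):
    """q-th percentile (0 <= q <= 100) of a flat sequence, NumPy-style indexing."""
    xs = sorted(data)
    pos = (q / 100.0) * (len(xs) - 1)  # fractional position in the sorted data
    lo, hi = math.floor(pos), math.ceil(pos)
    i, j = xs[lo], xs[hi]  # the two surrounding data points, i <= j
    fraction = pos - lo
    if interpolation == "linear":
        return i + (j - i) * fraction
    if interpolation == "lower":
        return i
    if interpolation == "higher":
        return j
    if interpolation == "midpoint":
        return (i + j) / 2
    if interpolation == "nearest":
        return xs[round(pos)]  # Python's round() breaks ties half-to-even
    raise ValueError(f"unknown interpolation method: {interpolation!r}")


data = [10, 7, 4, 3, 2, 1]  # the example tensor, flattened
for method in ["linear", "lower", "higher", "midpoint", "nearest"]:
    print(method, percentile(data, 50, method))
```

With the example data the sorted values are [1, 2, 3, 4, 7, 10] and the 50th percentile sits at index 2.5, between 3 and 4, so 'linear' gives 3.5 as in the example above.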
- poisson(n, lam)
The continuous approximation, using \(n! = \Gamma\left(n+1\right)\), to the probability mass function of the Poisson distribution evaluated at n given the parameter lam.
Note
Though the p.m.f. of the Poisson distribution is not defined for \(\lambda = 0\), the limit as \(\lambda \to 0\) is still defined, which gives a degenerate p.m.f. of
\[\lim_{\lambda \to 0} \,\mathrm{Pois}(n | \lambda) = \begin{cases} 1, & n = 0,\\ 0, & n > 0 \end{cases}\]
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.poisson(5., 6.)
Array(0.16062314, dtype=float64, weak_type=True)
>>> values = pyhf.tensorlib.astensor([5., 9.])
>>> rates = pyhf.tensorlib.astensor([6., 8.])
>>> pyhf.tensorlib.poisson(values, rates)
Array([0.16062314, 0.12407692], dtype=float64)
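- Parameters:
n (tensor or float) – The value at which to evaluate the approximation to the Poisson p.m.f. (the observed number of events)
lam (tensor or float) – The mean of the Poisson distribution (the expected number of events)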
- Returns:
Value of the continuous approximation to Poisson(n|lam)
- Return type:
JAX ndarray
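A scalar, standard-library sketch of the continuous approximation, computed in log space as exp(n log(lam) - lam - log Γ(n + 1)) to avoid overflow in Γ(n + 1). The explicit branch for lam = 0 encodes the degenerate limit from the note; poisson_continuous is a hypothetical helper, and how the backend itself handles lam = 0 is not shown here.

```python
import math


def poisson_continuous(n: float, lam: float) -> float:
    """Continuous Poisson approximation with n! replaced by Gamma(n + 1)."""
    if lam < 0:
        raise ValueError("lam must be non-negative")
    if lam == 0:
        # Degenerate lam -> 0 limit: all probability sits at n = 0.
        return 1.0 if n == 0 else 0.0
    # Log space: log Pois(n | lam) = n*log(lam) - lam - log Gamma(n + 1).
    return math.exp(n * math.log(lam) - lam - math.lgamma(n + 1.0))


print(poisson_continuous(5.0, 6.0))  # ≈ 0.16062314
print(poisson_continuous(9.0, 8.0))  # ≈ 0.12407692
print(poisson_continuous(2.5, 3.0))  # non-integer n is allowed by the Gamma form
```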
- poisson_dist(rate)
The Poisson distribution with rate parameter rate.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> rates = pyhf.tensorlib.astensor([5, 8])
>>> values = pyhf.tensorlib.astensor([4, 9])
>>> poissons = pyhf.tensorlib.poisson_dist(rates)
>>> poissons.log_prob(values)
Array([-1.74030218, -2.0868536 ], dtype=float64)
- Parameters:
rate (tensor or float) – The mean of the Poisson distribution (the expected number of events)
- Returns:
The Poisson distribution class
- Return type:
Poisson distribution
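The log_prob values in the example match the logarithm of the continuous Poisson approximation, n log(rate) - rate - log Γ(n + 1). A standard-library sketch that recomputes them; poisson_log_prob is a hypothetical helper, not pyhf's distribution class, and it assumes rate > 0 (math.log(0) raises).

```python
import math


def poisson_log_prob(n: float, rate: float) -> float:
    """Log of the continuous Poisson approximation; requires rate > 0."""
    return n * math.log(rate) - rate - math.lgamma(n + 1.0)


rates = [5.0, 8.0]
values = [4.0, 9.0]
log_probs = [poisson_log_prob(n, r) for n, r in zip(values, rates)]
print(log_probs)  # ≈ [-1.74030218, -2.0868536]
```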
- ravel(tensor)
Return a flattened (one-dimensional) version of the tensor. JAX arrays are immutable, so unlike numpy.ravel no view of the input is returned.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> pyhf.tensorlib.ravel(tensor)
Array([1., 2., 3., 4., 5., 6.], dtype=float64)
- Parameters:
tensor (Tensor) – Tensor object
- Returns:
A flattened array.
- Return type:
jaxlib.xla_extension.Array
- simple_broadcast(*args)
Broadcast a sequence of one-dimensional arrays.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> pyhf.tensorlib.simple_broadcast(
...     pyhf.tensorlib.astensor([1]),
...     pyhf.tensorlib.astensor([2, 3, 4]),
...     pyhf.tensorlib.astensor([5, 6, 7]))
[Array([1., 1., 1.], dtype=float64), Array([2., 3., 4.], dtype=float64), Array([5., 6., 7.], dtype=float64)]
- Parameters:
args (Array of Tensors) – Sequence of arrays
- Returns:
The sequence broadcast together.
- Return type:
list of Tensors
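A NumPy sketch of the broadcasting rule illustrated in the example: a length-1 array is stretched to the common length, and arrays whose lengths differ and are both greater than 1 are rejected. The function is a hypothetical reimplementation for illustration only; that the JAX backend rejects mismatched lengths in the same way is an assumption based on standard broadcasting rules.

```python
import numpy as np


def simple_broadcast(*args):
    """Hypothetical sketch: broadcast 1-D arrays to a common length."""
    arrays = [np.asarray(a, dtype=float) for a in args]
    if any(a.ndim != 1 for a in arrays):
        raise ValueError("simple_broadcast expects 1-D arrays")
    length = max(a.shape[0] for a in arrays)
    # broadcast_to stretches length-1 arrays and raises ValueError
    # for any other length that differs from the target.
    return [np.broadcast_to(a, (length,)) for a in arrays]


out = simple_broadcast([1], [2, 3, 4], [5, 6, 7])
print([a.tolist() for a in out])  # [[1.0, 1.0, 1.0], [2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]
```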
- tile(tensor_in, repeats)
Repeat the tensor data along each dimension according to the multipliers in repeats.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([[1.0], [2.0]])
>>> pyhf.tensorlib.tile(a, (1, 2))
Array([[1., 1.],
       [2., 2.]], dtype=float64)
- Parameters:
tensor_in (tensor) – The tensor to be repeated
repeats (tensor) – The tuple of multipliers for each dimension
- Returns:
The tensor with repeated axes
- Return type:
JAX ndarray
- to_numpy(tensor_in)
Convert the JAX tensor to a numpy.ndarray.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
Array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float64)
>>> numpy_ndarray = pyhf.tensorlib.to_numpy(tensor)
>>> numpy_ndarray
array([[1., 2., 3.],
       [4., 5., 6.]])
>>> type(numpy_ndarray)
<class 'numpy.ndarray'>
- Parameters:
tensor_in (tensor) – The input tensor object.
- Returns:
The tensor converted to a NumPy ndarray.
- Return type:
numpy.ndarray
- transpose(tensor_in)
Transpose the tensor.
Example
>>> import pyhf
>>> pyhf.set_backend("jax")
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tensor
Array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float64)
>>> pyhf.tensorlib.transpose(tensor)
Array([[1., 4.],
       [2., 5.],
       [3., 6.]], dtype=float64)
- Parameters:
tensor_in (tensor) – The input tensor object.
- Returns:
The transpose of the input tensor.
- Return type:
JAX ndarray
Added in version 0.7.0.