from __future__ import annotations
import warnings
from typing import Any
import numpy as np
from ..utils import eval_pdf
from ..utils.fit.api_check import is_valid_pdf
from .exceptions import ModelNotFittedToData
from .warnings import AboveToleranceWarning
[docs]
def is_sum_of_extended_pdfs(model) -> bool:
"""Checks if the input model is a sum of extended models.
Args:
model: the input model/pdf
Returns:
True if the model is a sum of extended models, False if not.
"""
if not hasattr(model, "get_models"):
return False
return all(m.is_extended for m in model.get_models()) and model.is_extended
[docs]
def compute_sweights(model, x: np.ndarray) -> dict[Any, np.ndarray]:
"""Computes sWeights from probability density functions for different components/species in a fit model
(for instance signal and background) fitted on some data `x`.
i.e. model = Nsig * pdf_signal + Nbkg * pdf_bkg
Args:
model: sum of extended pdfs.
x: data on which `model` is fitted
Returns:
dictionary with yield parameters as keys, and sWeights for correspoind species as values.
Example with **zfit**:
Imports:
>>> import numpy as np
>>> import zfit
>>> from zfit.loss import ExtendedUnbinnedNLL
>>> from zfit.minimize import Minuit
Definition of the bounds and yield of background and signal species:
>>> bounds = (0.0, 3.0)
>>> nbkg = 10000
>>> nsig = 5000
>>> obs = zfit.Space('x', limits=bounds)
Generation of data:
>>> bkg = np.random.exponential(0.5, nbkg)
>>> peak = np.random.normal(1.2, 0.1, nsig)
>>> data = np.concatenate((bkg, peak))
>>> data = data[(data > bounds[0]) & (data < bounds[1])]
>>> N = data.size
>>> data = zfit.data.Data.from_numpy(obs=obs, array=data)
Model definition:
>>> mean = zfit.Parameter("mean", 1.2, 0.5, 2.0)
>>> sigma = zfit.Parameter("sigma", 0.1, 0.02, 0.2)
>>> lambda_ = zfit.Parameter("lambda", -2.0, -4.0, -1.0)
>>> Nsig = zfit.Parameter("Nsig", nsig, 0., N)
>>> Nbkg = zfit.Parameter("Nbkg", nbkg, 0., N)
>>> signal = zfit.pdf.Gauss(obs=obs, mu=mean, sigma=sigma).create_extended(Nsig)
>>> background = zfit.pdf.Exponential(obs=obs, lambda_=lambda_).create_extended(Nbkg)
>>> tot_model = zfit.pdf.SumPDF([signal, background])
Loss construction and minimization:
>>> loss = ExtendedUnbinnedNLL(model=signal + background, data=data)
>>> minimizer = Minuit()
>>> minimum = minimizer.minimize(loss)
sWeights computation:
>>> from hepstats.splot import compute_sweights
>>> sweights = compute_sweights(tot_model, data)
>>> print(sweights)
{<zfit.Parameter 'Nsig' floating=True value=4985>: array([-0.09953299, -0.09953299, -0.09953299, ...,
0.78689884, 1.08823111, 1.05948873]),
<zfit.Parameter 'Nbkg' floating=True value=9989>: array([ 1.09953348, 1.09953348, 1.09953348, ...,
0.21310097, -0.08823153, -0.05948912])}
"""
if not is_valid_pdf(model):
msg = f"{model} is not a valid pdf!"
raise ValueError(msg)
if not is_sum_of_extended_pdfs(model):
msg = f"Input model, {model}, should be a sum of extended pdfs!"
raise ValueError(msg)
models = model.get_models()
yields = [m.get_yield() for m in models]
p = np.vstack([eval_pdf(m, x) for m in models]).T
Nx = eval_pdf(model, x, allow_extended=True)
pN = p / Nx[:, None]
MLSR = pN.sum(axis=0)
atol_warning = 5e-3
atol_exceptions = 5e-2
def msg_fn(tolerance):
msg = (
"The Maximum Likelihood Sum Rule sanity check, described in equation 17 of"
+ " arXiv:physics/0402083, failed. According to this check the following quantities\n"
)
for y, mlsr in zip(yields, MLSR):
msg += f"\t* {y.name}: {mlsr},\n"
msg += f"should be equal to 1.0 with an absolute tolerance of {tolerance}."
return msg
if not np.allclose(MLSR, 1, atol=atol_exceptions):
msg = msg_fn(atol_exceptions)
msg += " The numbers suggest that the model is not fitted to the data. Please check your fit."
raise ModelNotFittedToData(msg)
if not np.allclose(MLSR, 1, atol=atol_warning):
msg = msg_fn(atol_warning)
msg += " If the fit to the data is good please ignore this warning."
warnings.warn(msg, AboveToleranceWarning, stacklevel=2)
Vinv = (pN).T.dot(pN)
V = np.linalg.inv(Vinv)
sweights = p.dot(V) / Nx[:, None]
return {y: sweights[:, i] for i, y in enumerate(yields)}