Source code for hepstats.utils.fit.api_check

"""
Module for testing a fitting library validity with hepstats.

A fitting library should provide six basic objects:

    * model / probability density function
    * parameters of the models
    * data
    * loss / likelihood function
    * minimizer
    * fitresult (optional)

A validation function for each object is defined in this module; each one must
return `True` for the object to work with hepstats.

The `zfit` API is currently the standard fitting API in hepstats.

"""

from __future__ import annotations

import warnings

import uhi.typing.plottable


def is_valid_parameter(object):
    """
    Check that a parameter object exposes the expected interface.

    The following attributes/methods must be present:

        * value
        * set_value
        * floating

    Args:
        object: candidate parameter object.

    Returns:
        `True` if all the attributes above are present, `False` otherwise.
    """
    required = ("value", "set_value", "floating")
    # A list (not a generator) so that every attribute is probed, whatever
    # the outcome of the previous lookups.
    present = [hasattr(object, name) for name in required]
    return all(present)
def is_valid_data(object):
    """
    Check that a data object exposes the expected interface.

    The object is accepted if it has all of the following attributes/methods:

        * nevents
        * weights
        * set_weights
        * space

    or if it is an instance of `uhi.typing.plottable.PlottableHistogram`.

    Args:
        object: candidate data object.

    Returns:
        `True` if the object qualifies as data, `False` otherwise.
    """
    can_resample = hasattr(object, "resample")

    try:
        nevents_ok = hasattr(object, "nevents")
    except RuntimeError:
        # Looking up `nevents` raised a RuntimeError (presumably a sampler whose
        # data has not been drawn yet -- TODO confirm). When the object can be
        # resampled, do so and look again; otherwise count `nevents` as missing.
        if not can_resample:
            nevents_ok = False
        else:
            object.resample()
            nevents_ok = hasattr(object, "nevents")

    # A list (not a generator) so that every attribute is probed.
    others_ok = all(
        [hasattr(object, name) for name in ("weights", "set_weights", "space")]
    )
    histlike = isinstance(object, uhi.typing.plottable.PlottableHistogram)

    return (nevents_ok and others_ok) or histlike
def is_valid_pdf(object):
    """
    Check that a pdf object exposes the expected interface.

    The following attributes/methods must be present:

        * get_params
        * pdf
        * integrate
        * sample
        * space
        * get_yield

    In addition, **is_valid_parameter** is called on each of the parameters
    returned by `get_params()`, and all of them must be valid.

    Args:
        object: candidate pdf / model object.

    Returns:
        `True` if the object and all its parameters are valid, `False` otherwise.
    """
    # Without `get_params` there is nothing more to inspect.
    if not hasattr(object, "get_params"):
        return False

    params_ok = all(is_valid_parameter(param) for param in object.get_params())

    required = ("pdf", "integrate", "sample", "space", "get_yield")
    # A list (not a generator) so that every attribute is probed.
    interface_ok = all([hasattr(object, name) for name in required])

    return params_ok and interface_ok
def is_valid_loss(object):
    """
    Check that a loss object exposes the expected interface.

    The following attributes/methods must be present:

        * model
        * data
        * get_params
        * constraints
        * create_new (a missing `create_new` currently only triggers a
          `FutureWarning`)

    In addition, **is_valid_pdf** is called on each of the models returned by
    `model`, and **is_valid_data** is called on each of the data objects
    returned by `data`.

    Args:
        object: candidate loss object.

    Returns:
        `True` if the loss, its models and its datasets are all valid,
        `False` otherwise.
    """
    # `model` and `data` are mandatory: bail out as soon as one is missing.
    if not hasattr(object, "model"):
        return False
    models = object.model

    if not hasattr(object, "data"):
        return False
    datasets = object.data

    get_params_ok = hasattr(object, "get_params")
    constraints_ok = hasattr(object, "constraints")

    create_new_ok = hasattr(object, "create_new")
    if not create_new_ok:
        warnings.warn("Loss should have a `create_new` method.", FutureWarning, stacklevel=3)
        create_new_ok = True  # TODO: allowed now, will be dropped in the future

    # Both collections are validated, whatever the outcome of the first one.
    pdfs_ok = all(is_valid_pdf(pdf) for pdf in models)
    datasets_ok = all(is_valid_data(dataset) for dataset in datasets)

    return all([pdfs_ok, datasets_ok, constraints_ok, create_new_ok, get_params_ok])
def is_valid_fitresult(object):
    """
    Check that a fit result object exposes the expected interface.

    The following attributes/methods must be present:

        * loss
        * params
        * covariance

    In addition, **is_valid_loss** is called with the `loss` attribute as
    argument and must succeed.

    Args:
        object: candidate fit result object.

    Returns:
        `True` if the fit result and its loss are valid, `False` otherwise.
    """
    # Without a loss there is nothing to validate further.
    if not hasattr(object, "loss"):
        return False
    loss = object.loss

    # A list (not a generator) so that both attributes are probed before the
    # loss itself is validated.
    extras_ok = all([hasattr(object, name) for name in ("params", "covariance")])

    return is_valid_loss(loss) and extras_ok
def is_valid_minimizer(object):
    """
    Check that a minimizer object exposes the expected interface.

    The following attribute/method must be present:

        * minimize

    Args:
        object: candidate minimizer object.

    Returns:
        `True` if the object has a `minimize` attribute, `False` otherwise.
    """
    has_minimize = hasattr(object, "minimize")
    return has_minimize