Tensorizing Interpolators

This notebook introduces some tensor algebra concepts for converting calculations inside for-loops into a single calculation over the entire tensor. It is assumed that you have some familiarity with what interpolation functions are used for in pyhf.

To get started, we’ll load some functions we wrote that generate sets of histograms and alphas to compute interpolations for. This gives us random, structured input data that we can use to test the tensorized form of the interpolation function against the original one we wrote. For now, we will consider only the numpy backend for simplicity, but one can replace np with pyhf.tensorlib to achieve identical functionality.

The function random_histosets_alphasets_pair will produce a pair (histogramsets, alphasets) of histograms and alphas for those histograms that represents the type of input we wish to interpolate on.

[1]:
import numpy as np


def random_histosets_alphasets_pair(
    nsysts=150, nhistos_per_syst_upto=300, nalphas=1, nbins_upto=1
):
    def generate_shapes(histogramssets, alphasets):
        h_shape = [len(histogramssets), 0, 0, 0]
        a_shape = (len(alphasets), max(map(len, alphasets)))
        for hs in histogramssets:
            h_shape[1] = max(h_shape[1], len(hs))
            for h in hs:
                h_shape[2] = max(h_shape[2], len(h))
                for sh in h:
                    h_shape[3] = max(h_shape[3], len(sh))
        return tuple(h_shape), a_shape

    def filled_shapes(histogramssets, alphasets):
        # pad our shapes with NaNs
        histos, alphas = generate_shapes(histogramssets, alphasets)
        histos, alphas = np.ones(histos) * np.nan, np.ones(alphas) * np.nan
        for i, syst in enumerate(histogramssets):
            for j, sample in enumerate(syst):
                for k, variation in enumerate(sample):
                    histos[i, j, k, : len(variation)] = variation
        for i, alphaset in enumerate(alphasets):
            alphas[i, : len(alphaset)] = alphaset
        return histos, alphas

    nsyst_histos = np.random.randint(1, 1 + nhistos_per_syst_upto, size=nsysts)
    nhistograms = [np.random.randint(1, nbins_upto + 1, size=n) for n in nsyst_histos]
    random_alphas = [np.random.uniform(-1, 1, size=nalphas) for n in nsyst_histos]

    random_histogramssets = [
        [  # all histos affected by systematic $nh
            [  # sample $i, systematic $nh
                np.random.uniform(10 * i + j, 10 * i + j + 1, size=nbin).tolist()
                for j in range(3)
            ]
            for i, nbin in enumerate(nh)
        ]
        for nh in nhistograms
    ]
    h, a = filled_shapes(random_histogramssets, random_alphas)
    return h, a
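
The NaN padding done by filled_shapes can be hard to picture from the full generator. The following is a minimal sketch of the same idea on a two-dimensional toy input; the names ragged and padded are made up for this illustration and are not part of the notebook.

```python
import numpy as np

# Hypothetical ragged input: two histograms with different numbers of bins.
ragged = [[1.0, 2.0, 3.0], [4.0]]

# Allocate a rectangular array of NaNs sized to the longest row,
# then copy each row into the leading slots.
padded = np.full((len(ragged), max(map(len, ragged))), np.nan)
for i, row in enumerate(ragged):
    padded[i, : len(row)] = row

# Masking out the NaNs recovers only the real entries (flattened).
real_entries = padded[~np.isnan(padded)]
print(padded.shape)
print(real_entries)
```

The rectangular shape is what lets later steps slice and broadcast over whole axes, while the NaN mask is how the padding can be ignored when comparing results.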

The (slow) interpolations

In all cases, the way we do interpolations is as follows:

  1. Loop over both the histogramssets and alphasets simultaneously (e.g. using python’s zip())

  2. Loop over all histograms in the set of histograms affected by a given systematic

  3. Loop over all of the alphas in the set of alphas

  4. Loop over all the bins in the histogram sets simultaneously (e.g. using python’s zip())

  5. Apply the interpolation across the same bin index

This is already exhausting to think about, so let’s put this in code form. Depending on the kind of interpolation being done, we’ll pass in func as an argument to the top-level interpolation loop to switch between linear (interpcode=0) and non-linear (interpcode=1).

[2]:
def interpolation_looper(histogramssets, alphasets, func):
    all_results = []
    for histoset, alphaset in zip(histogramssets, alphasets):
        all_results.append([])
        set_result = all_results[-1]
        for histo in histoset:
            set_result.append([])
            histo_result = set_result[-1]
            for alpha in alphaset:
                alpha_result = []
                for down, nom, up in zip(histo[0], histo[1], histo[2]):
                    v = func(down, nom, up, alpha)
                    alpha_result.append(v)
                histo_result.append(alpha_result)
    return all_results

We can also define the linear and non-linear interpolations that we wish to tensorize in this notebook.

[3]:
def interpolation_linear(histogramssets, alphasets):
    def summand(down, nom, up, alpha):
        delta_up = up - nom
        delta_down = nom - down
        if alpha > 0:
            delta = delta_up * alpha
        else:
            delta = delta_down * alpha
        return nom + delta

    return interpolation_looper(histogramssets, alphasets, summand)


def interpolation_nonlinear(histogramssets, alphasets):
    def product(down, nom, up, alpha):
        delta_up = up / nom
        delta_down = down / nom
        if alpha > 0:
            delta = delta_up**alpha
        else:
            delta = delta_down ** (-alpha)
        return nom * delta

    return interpolation_looper(histogramssets, alphasets, product)
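
To make the two formulas concrete, here is a small worked example on a single bin. The values of down, nom, and up are made up for illustration, and the helper names linear_bin and nonlinear_bin are hypothetical stand-ins for the summand and product closures above.

```python
def linear_bin(down, nom, up, alpha):
    # Piecewise-linear: scale the up or down difference by alpha.
    if alpha > 0:
        return nom + (up - nom) * alpha
    return nom + (nom - down) * alpha


def nonlinear_bin(down, nom, up, alpha):
    # Piecewise-exponential: raise the up or down ratio to |alpha|.
    if alpha > 0:
        return nom * (up / nom) ** alpha
    return nom * (down / nom) ** (-alpha)


down, nom, up = 8.0, 10.0, 13.0  # illustrative bin contents

lin_half_up = linear_bin(down, nom, up, 0.5)  # 10 + 3 * 0.5
lin_half_dn = linear_bin(down, nom, up, -0.5)  # 10 + 2 * (-0.5)
nonlin_full_up = nonlinear_bin(down, nom, up, 1.0)  # recovers up (to rounding)
nonlin_full_dn = nonlinear_bin(down, nom, up, -1.0)  # recovers down (to rounding)
print(lin_half_up, lin_half_dn, nonlin_full_up, nonlin_full_dn)
```

Both interpolations return the nominal value at alpha = 0 and reproduce the up and down variations at alpha = +1 and alpha = -1 respectively.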

We will also define a helper function that takes two functions whose outputs we wish to compare:

[4]:
def compare_fns(func1, func2):
    h, a = random_histosets_alphasets_pair()

    def _func_runner(func, histssets, alphasets):
        return np.asarray(func(histssets, alphasets))

    old = _func_runner(func1, h, a)
    new = _func_runner(func2, h, a)

    return (np.all(old[~np.isnan(old)] == new[~np.isnan(new)]), (h, a))
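
Note that compare_fns masks out the NaN padding and then checks for exact floating-point equality. That works here because each step performs the same arithmetic in the same order. If a rewrite reorders operations, or another backend is used, a tolerance-based check may be more appropriate. The sketch below is one possible alternative and not part of the notebook; the arrays old and new are made-up stand-ins for two results.

```python
import numpy as np

# Hypothetical results from two implementations, with NaN padding.
old = np.array([[1.0, 2.0, np.nan], [3.0, np.nan, np.nan]])
new = old + 1e-12  # tiny rounding-level differences; NaN stays NaN

# Exact comparison on the non-NaN entries, as compare_fns does.
exact = bool(np.all(old[~np.isnan(old)] == new[~np.isnan(new)]))

# Tolerance-based comparison that treats NaN padding as equal.
close = bool(np.allclose(old, new, equal_nan=True))

print(exact, close)
```

Unlike the masked comparison, np.allclose with equal_nan=True also checks that the NaN padding sits in the same positions in both arrays.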

For the rest of the notebook, we will detail in explicit form how the linear interpolator gets tensorized, step-by-step. The same sequence of steps will be shown for the non-linear interpolator – but it is left up to the reader to understand the steps.

Tensorizing the Linear Interpolator

Step 0

Step 0 requires converting the innermost conditional check on alpha > 0 into something tensorizable. This also means the calculation itself is going to become tensorized. So we will convert from

if alpha > 0:
    delta =  delta_up*alpha
else:
    delta =  delta_down*alpha

to

delta = np.where(alpha > 0, delta_up*alpha, delta_down*alpha)
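
As a quick sanity check of this replacement, the sketch below compares the if/else branch with np.where over a handful of alpha values. The numbers for delta_up and delta_down are made up for illustration.

```python
import numpy as np

delta_up, delta_down = 3.0, 2.0  # illustrative values


def branch(alpha):
    # The original scalar conditional.
    if alpha > 0:
        return delta_up * alpha
    return delta_down * alpha


alphas = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])

looped = np.array([branch(alpha) for alpha in alphas])
# np.where evaluates both candidates for every element, then selects.
vectorized = np.where(alphas > 0, delta_up * alphas, delta_down * alphas)

print(looped)
print(vectorized)
```

Because np.where computes both branches before selecting, it does more arithmetic per element than the conditional, which is one reason the scalar use inside the loops is slower at this stage.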

Let’s make that change now, and let’s check to make sure we still do the calculation correctly.

[5]:
# get the internal calculation to use tensorlib backend
def new_interpolation_linear_step0(histogramssets, alphasets):
    all_results = []
    for histoset, alphaset in zip(histogramssets, alphasets):
        all_results.append([])
        set_result = all_results[-1]
        for histo in histoset:
            set_result.append([])
            histo_result = set_result[-1]
            for alpha in alphaset:
                alpha_result = []
                for down, nom, up in zip(histo[0], histo[1], histo[2]):
                    delta_up = up - nom
                    delta_down = nom - down
                    delta = np.where(alpha > 0, delta_up * alpha, delta_down * alpha)
                    v = nom + delta
                    alpha_result.append(v)
                histo_result.append(alpha_result)
    return all_results

And does the calculation still match?

[6]:
result, (h, a) = compare_fns(interpolation_linear, new_interpolation_linear_step0)
print(result)
True
[7]:
%%timeit
interpolation_linear(h, a)
189 ms ± 6.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[8]:
%%timeit
new_interpolation_linear_step0(h, a)
255 ms ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Great! We’re a little bit slower right now, but that’s expected. We’re just getting started.

Step 1

In this step, we would like to remove the innermost zip() call over the histogram bins by calculating the interpolation between the histograms in one fell swoop. This means, instead of writing something like

for down,nom,up in zip(histo[0],histo[1],histo[2]):
    delta_up = up - nom
    ...

one can instead write

delta_up = histo[2] - histo[1]
...

taking advantage of the automatic broadcasting of operations on input tensors. This sort of feature of the tensor backends allows us to speed up code, such as interpolation.
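
The equivalence between the zip() loop and the whole-array arithmetic can be checked on a toy histogram. In this sketch, histo is a made-up array whose rows are the down, nominal, and up variations.

```python
import numpy as np

# Rows: down, nominal, up; columns: bins (illustrative values).
histo = np.array([[8.0, 18.0], [10.0, 20.0], [13.0, 21.0]])
alpha = 0.5

# Bin-by-bin loop, as in the step 0 version.
looped = []
for down, nom, up in zip(histo[0], histo[1], histo[2]):
    if alpha > 0:
        delta = (up - nom) * alpha
    else:
        delta = (nom - down) * alpha
    looped.append(nom + delta)

# Whole-array arithmetic: every bin is handled in one expression.
deltas_up = histo[2] - histo[1]
deltas_dn = histo[1] - histo[0]
vectorized = histo[1] + np.where(alpha > 0, deltas_up * alpha, deltas_dn * alpha)

print(looped, vectorized)
```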

[9]:
# update the delta variations to remove the zip() call and remove most-nested loop
def new_interpolation_linear_step1(histogramssets, alphasets):
    all_results = []
    for histoset, alphaset in zip(histogramssets, alphasets):
        all_results.append([])
        set_result = all_results[-1]
        for histo in histoset:
            set_result.append([])
            histo_result = set_result[-1]
            for alpha in alphaset:
                deltas_up = histo[2] - histo[1]
                deltas_dn = histo[1] - histo[0]
                calc_deltas = np.where(alpha > 0, deltas_up * alpha, deltas_dn * alpha)
                v = histo[1] + calc_deltas
                # v already holds every bin for this alpha, so append it directly
                # (wrapping it in another list would add a spurious extra axis)
                histo_result.append(v)
    return all_results

And does the calculation still match?

[10]:
result, (h, a) = compare_fns(interpolation_linear, new_interpolation_linear_step1)
print(result)
True
[11]:
%%timeit
interpolation_linear(h, a)
188 ms ± 7.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[12]:
%%timeit
new_interpolation_linear_step1(h, a)
492 ms ± 42.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Great!

Step 2

In this step, we would like to compute the giant array of all deltas at the beginning, outside of all loops, and then only take a subset of it for the calculation itself. This allows us to figure out the entire structure of the input for the rest of the calculations as we slowly move towards including einsum() calls (Einstein summation). This means we would like to go from

for histo in histoset:
    delta_up = histo[2] - histo[1]
...

to

all_deltas = ...
for nh, histo in enumerate(histoset):
    deltas = all_deltas[nh]
    ...

Again, we are taking advantage of the automatic broadcasting of operations on input tensors to calculate all the deltas in a single action.
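
The sketch below illustrates the slicing on a small random array laid out as (set, histogram, variation, bin). The shape and random values are assumptions made for this example only.

```python
import numpy as np

rng = np.random.default_rng(0)
# Axes: (set, histogram, variation [down, nom, up], bin)
histogramssets = rng.uniform(1.0, 2.0, size=(2, 3, 3, 4))

# One subtraction over the whole tensor; the variation axis is sliced away.
all_deltas_up = histogramssets[:, :, 2] - histogramssets[:, :, 1]
print(all_deltas_up.shape)  # (2, 3, 4)

# Each per-histogram delta is just a view into the big array.
matches = all(
    np.array_equal(all_deltas_up[nset, nh], histo[2] - histo[1])
    for nset, histoset in enumerate(histogramssets)
    for nh, histo in enumerate(histoset)
)
print(matches)
```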

[13]:
# figure out the giant array of all deltas at the beginning and only take subsets of it for the calculation
def new_interpolation_linear_step2(histogramssets, alphasets):
    all_results = []

    allset_all_histo_deltas_up = histogramssets[:, :, 2] - histogramssets[:, :, 1]
    allset_all_histo_deltas_dn = histogramssets[:, :, 1] - histogramssets[:, :, 0]

    for nset, (histoset, alphaset) in enumerate(zip(histogramssets, alphasets)):
        set_result = []

        all_histo_deltas_up = allset_all_histo_deltas_up[nset]
        all_histo_deltas_dn = allset_all_histo_deltas_dn[nset]

        for nh, histo in enumerate(histoset):
            alpha_deltas = []
            for alpha in alphaset:
                alpha_result = []
                deltas_up = all_histo_deltas_up[nh]
                deltas_dn = all_histo_deltas_dn[nh]
                calc_deltas = np.where(alpha > 0, deltas_up * alpha, deltas_dn * alpha)
                alpha_deltas.append(calc_deltas)
            set_result.append([histo[1] + d for d in alpha_deltas])
        all_results.append(set_result)
    return all_results

And does the calculation still match?

[14]:
result, (h, a) = compare_fns(interpolation_linear, new_interpolation_linear_step2)
print(result)
True
[15]:
%%timeit
interpolation_linear(h, a)
179 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[16]:
%%timeit
new_interpolation_linear_step2(h, a)
409 ms ± 20.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Great!

Step 3

In this step, we get to introduce Einstein summation to generalize the calculations we perform across many dimensions in a more concise, straightforward way. See this blog post for some more details on Einstein summation notation. In short, it allows us to write

\[c_j = \sum_i \sum_k A_{ik} B_{kj} \qquad \rightarrow \qquad \texttt{einsum("ik,kj->j", A, B)}\]

which is a much more elegant way to express many kinds of common tensor operations such as dot products, transposes, outer products, and so on. This step is generally the hardest, as one needs to figure out the corresponding einsum that preserves the calculation (so that the results still match). To some extent it requires trial and error until you get a feel for how Einstein summation notation works.
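
For readers new to the notation, here is a short sketch of a few common operations written with np.einsum and checked against their usual NumPy counterparts. The arrays are arbitrary illustrative values.

```python
import numpy as np

A = np.arange(6.0).reshape(2, 3)
B = np.arange(12.0).reshape(3, 4)
u = np.array([1.0, 2.0])
v = np.array([3.0, 4.0, 5.0])

# Index k is repeated and absent from the output, so it is summed over.
matmul = np.einsum("ik,kj->ij", A, B)
# Dropping i from the output as well sums over both i and k.
col_sums = np.einsum("ik,kj->j", A, B)
# No repeated index: every pair (i, j) is multiplied, giving an outer product.
outer = np.einsum("i,j->ij", u, v)
# Reordering the output indices gives a transpose.
transpose = np.einsum("ij->ji", A)
# An empty output sums over everything: a dot product.
dot = np.einsum("i,i->", v, v)

print(matmul.shape, col_sums.shape, outer.shape, transpose.shape, dot)
```

The rule of thumb is that any index missing from the right-hand side of the arrow is summed over, and any index kept there remains an axis of the result.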

As a concrete example of a conversion, we wish to go from something like

for nh,histo in enumerate(histoset):
    for alpha in alphaset:
        deltas_up    = all_histo_deltas_up[nh]
        deltas_dn    = all_histo_deltas_dn[nh]
        calc_deltas  = np.where(alpha > 0, deltas_up*alpha, deltas_dn*alpha)
        ...

to get rid of the loop over alpha

for nh,histo in enumerate(histoset):
    alphas_times_deltas_up = np.einsum('i,j->ij',alphaset,all_histo_deltas_up[nh])
    alphas_times_deltas_dn = np.einsum('i,j->ij',alphaset,all_histo_deltas_dn[nh])
    masks = np.einsum('i,j->ij',alphaset > 0,np.ones_like(all_histo_deltas_dn[nh]))

    alpha_deltas  = np.where(masks,alphas_times_deltas_up, alphas_times_deltas_dn)
    ...

In this particular case, we need an outer product that multiplies each alpha in the alphaset with the corresponding up/down deltas of the histogram. Then we just need to select either the up variation calculation or the down variation calculation based on the sign of alpha. Try to convince yourself that the Einstein summation does what the for-loop does, but a little more concisely, and perhaps more clearly! How does the function look now?
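
Before the full function, here is a toy check of the outer-product step on its own. The values of alphaset, deltas_up, and deltas_dn are made up, and the last expression shows an equivalent formulation with plain broadcasting as an aside.

```python
import numpy as np

alphaset = np.array([-0.5, 0.25, 1.0])
deltas_up = np.array([3.0, 1.0])
deltas_dn = np.array([2.0, 2.0])

# Loop over alpha, as in step 2.
looped = np.array(
    [np.where(alpha > 0, deltas_up * alpha, deltas_dn * alpha) for alpha in alphaset]
)

# Outer products give one row per alpha and one column per bin.
up = np.einsum("i,j->ij", alphaset, deltas_up)
dn = np.einsum("i,j->ij", alphaset, deltas_dn)
# The mask is repeated across bins so it has the same (alpha, bin) shape.
masks = np.einsum("i,j->ij", alphaset > 0, np.ones_like(deltas_dn))
tensorized = np.where(masks, up, dn)

# Same result using broadcasting instead of einsum.
col = alphaset[:, None]
broadcast = np.where(col > 0, col * deltas_up, col * deltas_dn)

print(tensorized)
```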

[17]:
# remove the loop over alphas, starts using einsum to help generalize to more dimensions
def new_interpolation_linear_step3(histogramssets, alphasets):
    all_results = []

    allset_all_histo_deltas_up = histogramssets[:, :, 2] - histogramssets[:, :, 1]
    allset_all_histo_deltas_dn = histogramssets[:, :, 1] - histogramssets[:, :, 0]

    for nset, (histoset, alphaset) in enumerate(zip(histogramssets, alphasets)):
        set_result = []

        all_histo_deltas_up = allset_all_histo_deltas_up[nset]
        all_histo_deltas_dn = allset_all_histo_deltas_dn[nset]

        for nh, histo in enumerate(histoset):
            alphas_times_deltas_up = np.einsum(
                'i,j->ij', alphaset, all_histo_deltas_up[nh]
            )
            alphas_times_deltas_dn = np.einsum(
                'i,j->ij', alphaset, all_histo_deltas_dn[nh]
            )
            masks = np.einsum(
                'i,j->ij', alphaset > 0, np.ones_like(all_histo_deltas_dn[nh])
            )

            alpha_deltas = np.where(
                masks, alphas_times_deltas_up, alphas_times_deltas_dn
            )
            set_result.append([histo[1] + d for d in alpha_deltas])

        all_results.append(set_result)
    return all_results

And does the calculation still match?

[18]:
result, (h, a) = compare_fns(interpolation_linear, new_interpolation_linear_step3)
print(result)
True
[19]:
%%timeit
interpolation_linear(h, a)
166 ms ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[20]:
%%timeit
new_interpolation_linear_step3(h, a)
921 ms ± 133 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Great! Note that we’ve been getting a little bit slower during these steps. It will all pay off in the end when we’re fully tensorized! At the moment, a lot of the internal steps are overkill: the heavy Einstein summation and broadcasting are repeated inside several nested loops.

Step 4

Now in this step, we will move the Einstein summations to the outer loop, so that we’re calculating them only once! This is the big step, but it is a little bit easier because all we’re doing is adding extra dimensions into the calculation. The underlying calculation won’t have changed. At this point, we’ll also rename the indices from i and j to a and b, for alpha and bin (as in the bin in the histogram). To continue the notation as well, here’s a summary of the dimensions involved:

  • s will be for the set under consideration (e.g. the modifier)

  • a will be for the alpha variation

  • h will be for the histogram affected by the modifier

  • b will be for the bin of the histogram
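
With these index letters, the subscript string 'sa,shb->shab' can be read directly as shapes. The sketch below uses made-up sizes for each dimension, checks one element by hand, and compares against an equivalent broadcasting expression.

```python
import numpy as np

rng = np.random.default_rng(1)
n_s, n_h, n_a, n_b = 2, 3, 4, 5  # illustrative sizes for s, h, a, b

alphasets = rng.uniform(-1.0, 1.0, size=(n_s, n_a))  # indices "sa"
deltas = rng.uniform(0.0, 1.0, size=(n_s, n_h, n_b))  # indices "shb"

# s is shared between the inputs; h, a, and b are all kept in the output.
out = np.einsum("sa,shb->shab", alphasets, deltas)
print(out.shape)  # (2, 3, 4, 5)

# Each output element is a single product, with nothing summed over.
s, h, a, b = 1, 2, 3, 4
element_ok = bool(np.isclose(out[s, h, a, b], alphasets[s, a] * deltas[s, h, b]))

# The same tensor via broadcasting with explicitly inserted axes.
broadcast = alphasets[:, None, :, None] * deltas[:, :, None, :]
print(element_ok, np.allclose(out, broadcast))
```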

So we wish to move the einsum code from

for nset,(histoset, alphaset) in enumerate(zip(histogramssets,alphasets)):
    ...

    for nh,histo in enumerate(histoset):
        alphas_times_deltas_up = np.einsum('i,j->ij',alphaset,all_histo_deltas_up[nh])
            ...

to

all_alphas_times_deltas_up = np.einsum('...',alphaset,all_histo_deltas_up)
for nset,(histoset, alphaset) in enumerate(zip(histogramssets,alphasets)):
    ...

    for nh,histo in enumerate(histoset):
        ...

So how does this new function look?

[21]:
# move the einsums to outer loops to get ready to get rid of all loops
def new_interpolation_linear_step4(histogramssets, alphasets):
    allset_all_histo_deltas_up = histogramssets[:, :, 2] - histogramssets[:, :, 1]
    allset_all_histo_deltas_dn = histogramssets[:, :, 1] - histogramssets[:, :, 0]
    allset_all_histo_nom = histogramssets[:, :, 1]

    allsets_all_histos_alphas_times_deltas_up = np.einsum(
        'sa,shb->shab', alphasets, allset_all_histo_deltas_up
    )
    allsets_all_histos_alphas_times_deltas_dn = np.einsum(
        'sa,shb->shab', alphasets, allset_all_histo_deltas_dn
    )
    allsets_all_histos_masks = np.einsum(
        'sa,s...u->s...au', alphasets > 0, np.ones_like(allset_all_histo_deltas_dn)
    )

    allsets_all_histos_deltas = np.where(
        allsets_all_histos_masks,
        allsets_all_histos_alphas_times_deltas_up,
        allsets_all_histos_alphas_times_deltas_dn,
    )

    all_results = []
    for nset, histoset in enumerate(histogramssets):
        all_histos_deltas = allsets_all_histos_deltas[nset]
        set_result = []
        for nh, histo in enumerate(histoset):
            set_result.append([d + histoset[nh, 1] for d in all_histos_deltas[nh]])
        all_results.append(set_result)
    return all_results

And does the calculation still match?

[22]:
result, (h, a) = compare_fns(interpolation_linear, new_interpolation_linear_step4)
print(result)
True
[23]:
%%timeit
interpolation_linear(h, a)
160 ms ± 5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[24]:
%%timeit
new_interpolation_linear_step4(h, a)
119 ms ± 3.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Great! And look at that speed-up already (from about 160 ms to about 119 ms in the timings above), just from moving the multiple, heavy Einstein summation calculations up through the loops. We still have some more optimizing to do as we still have explicit loops in our code. Let’s keep at it, we’re almost there!

Step 5

The hard part is mostly over. We have to now think about the nominal variations. Recall that we were trying to add the nominals to the deltas in order to compute the new value. In practice, we’ll return the delta variation only, but we’ll show you how to get rid of this last loop. In this case, we want to figure out how to change code like

all_results    = []
for nset,histoset in enumerate(histogramssets):
    all_histos_deltas = allsets_all_histos_deltas[nset]
    set_result = []
    for nh,histo in enumerate(histoset):
        set_result.append([d + histoset[nh,1] for d in all_histos_deltas[nh]])
    all_results.append(set_result)

to get rid of that most-nested loop

all_results    = []
for nset,histoset in enumerate(histogramssets):
    # look ma, no more loops inside!
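
One way to drop that inner loop is to repeat the nominal histograms along a new alpha axis so that they can be added to the deltas in a single operation. The sketch below checks this idea for one set, using made-up sizes and random values; the broadcasting line at the end is an equivalent alternative shown for comparison.

```python
import numpy as np

rng = np.random.default_rng(2)
n_h, n_a, n_b = 3, 2, 4  # illustrative sizes for histograms, alphas, bins

noms = rng.uniform(1.0, 2.0, size=(n_h, n_b))
all_histos_deltas = rng.uniform(-1.0, 1.0, size=(n_h, n_a, n_b))

# Loop version: add the nominal of each histogram to each alpha's deltas.
looped = np.array(
    [[d + noms[nh] for d in all_histos_deltas[nh]] for nh in range(n_h)]
)

# Repeat the nominals along the alpha axis, then add everything at once.
noms_repeated = np.einsum("a,hn->han", np.ones(n_a), noms)
tensorized = all_histos_deltas + noms_repeated

# Broadcasting achieves the same without materializing the repeats.
broadcast = all_histos_deltas + noms[:, None, :]

print(tensorized.shape)
```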

So how does this look?

[25]:
# slowly getting rid of our loops to build the right output tensor -- gotta think about nominals
def new_interpolation_linear_step5(histogramssets, alphasets):
    allset_all_histo_deltas_up = histogramssets[:, :, 2] - histogramssets[:, :, 1]
    allset_all_histo_deltas_dn = histogramssets[:, :, 1] - histogramssets[:, :, 0]
    allset_all_histo_nom = histogramssets[:, :, 1]

    allsets_all_histos_alphas_times_deltas_up = np.einsum(
        'sa,shb->shab', alphasets, allset_all_histo_deltas_up
    )
    allsets_all_histos_alphas_times_deltas_dn = np.einsum(
        'sa,shb->shab', alphasets, allset_all_histo_deltas_dn
    )
    allsets_all_histos_masks = np.einsum(
        'sa,s...u->s...au', alphasets > 0, np.ones_like(allset_all_histo_deltas_dn)
    )

    allsets_all_histos_deltas = np.where(
        allsets_all_histos_masks,
        allsets_all_histos_alphas_times_deltas_up,
        allsets_all_histos_alphas_times_deltas_dn,
    )

    all_results = []

    for nset, (_, alphaset) in enumerate(zip(histogramssets, alphasets)):
        all_histos_deltas = allsets_all_histos_deltas[nset]
        noms = histogramssets[nset, :, 1]

        all_histos_noms_repeated = np.einsum('a,hn->han', np.ones_like(alphaset), noms)

        set_result = all_histos_deltas + all_histos_noms_repeated
        all_results.append(set_result)
    return all_results

And does the calculation still match?

[26]:
result, (h, a) = compare_fns(interpolation_linear, new_interpolation_linear_step5)
print(result)
True
[27]:
%%timeit
interpolation_linear(h, a)
160 ms ± 8.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[28]:
%%timeit
new_interpolation_linear_step5(h, a)
1.57 ms ± 75.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Fantastic! And look at the speed up. We’re already faster than the for-loop and we’re not even done yet.

Step 6

The final frontier. Also probably the best Star Wars episode. In any case, we have one more for-loop that needs to die in a slab of carbonite. This should be much easier now that you’re more comfortable with tensor broadcasting and Einstein summations.

What does the function look like now?

[29]:
def new_interpolation_linear_step6(histogramssets, alphasets):
    allset_allhisto_deltas_up = histogramssets[:, :, 2] - histogramssets[:, :, 1]
    allset_allhisto_deltas_dn = histogramssets[:, :, 1] - histogramssets[:, :, 0]
    allset_allhisto_nom = histogramssets[:, :, 1]

    # x is dummy index

    allsets_allhistos_alphas_times_deltas_up = np.einsum(
        'sa,shb->shab', alphasets, allset_allhisto_deltas_up
    )
    allsets_allhistos_alphas_times_deltas_dn = np.einsum(
        'sa,shb->shab', alphasets, allset_allhisto_deltas_dn
    )
    allsets_allhistos_masks = np.einsum(
        'sa,sxu->sxau',
        np.where(alphasets > 0, np.ones(alphasets.shape), np.zeros(alphasets.shape)),
        np.ones(allset_allhisto_deltas_dn.shape),
    )

    allsets_allhistos_deltas = np.where(
        allsets_allhistos_masks,
        allsets_allhistos_alphas_times_deltas_up,
        allsets_allhistos_alphas_times_deltas_dn,
    )
    allsets_allhistos_noms_repeated = np.einsum(
        'sa,shb->shab', np.ones(alphasets.shape), allset_allhisto_nom
    )
    set_results = allsets_allhistos_deltas + allsets_allhistos_noms_repeated
    return set_results

And does the calculation still match?

[30]:
result, (h, a) = compare_fns(interpolation_linear, new_interpolation_linear_step6)
print(result)
True
[31]:
%%timeit
interpolation_linear(h, a)
156 ms ± 6.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[32]:
%%timeit
new_interpolation_linear_step6(h, a)
468 µs ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

And we’re done tensorizing it. There are some more improvements that could be made to make this interpolation calculation even more robust – but for now we’re done.
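
As one illustration of a possible further simplification, the same calculation can be written with plain broadcasting instead of einsum. This is only a sketch under the same input layout, (set, histogram, variation, bin) for the histograms and (set, alpha) for the alphas. It is not presented as what pyhf does internally, and both function names are hypothetical.

```python
import numpy as np


def interpolation_linear_broadcast(histogramssets, alphasets):
    # Insert a length-1 alpha axis into the histograms: (s, h, 1, b).
    down = histogramssets[:, :, 0][:, :, None, :]
    nom = histogramssets[:, :, 1][:, :, None, :]
    up = histogramssets[:, :, 2][:, :, None, :]
    # Insert length-1 histogram and bin axes into the alphas: (s, 1, a, 1).
    alphas = alphasets[:, None, :, None]
    deltas = np.where(alphas > 0, (up - nom) * alphas, (nom - down) * alphas)
    return nom + deltas  # shape (s, h, a, b)


def interpolation_linear_loops(histogramssets, alphasets):
    # Reference implementation with explicit loops, for comparison only.
    results = []
    for histoset, alphaset in zip(histogramssets, alphasets):
        set_result = []
        for histo in histoset:
            histo_result = []
            for alpha in alphaset:
                row = []
                for down, nom, up in zip(histo[0], histo[1], histo[2]):
                    delta = (up - nom) * alpha if alpha > 0 else (nom - down) * alpha
                    row.append(nom + delta)
                histo_result.append(row)
            set_result.append(histo_result)
        results.append(set_result)
    return np.asarray(results)


rng = np.random.default_rng(3)
toy_histos = rng.uniform(1.0, 2.0, size=(2, 3, 3, 4))
toy_alphas = rng.uniform(-1.0, 1.0, size=(2, 5))

fast = interpolation_linear_broadcast(toy_histos, toy_alphas)
slow = interpolation_linear_loops(toy_histos, toy_alphas)
print(fast.shape, np.allclose(fast, slow))
```

Broadcasting avoids building the arrays of ones that the einsum version uses to repeat masks and nominals, at the cost of making the axis bookkeeping implicit in the indexing.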

Tensorizing the Non-Linear Interpolator

This is very, very similar to what we’ve done for the case of the linear interpolator. As such, we will provide the resulting functions for each step, and you can see how things perform all the way at the bottom. Enjoy and learn at your own pace!
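
The one new ingredient compared with the linear case is how the conditional is rewritten: the sign of alpha selects the base (up or down ratio), and the exponent becomes the absolute value of alpha. The sketch below checks this reformulation on a single bin with made-up values.

```python
import numpy as np

down, nom, up = 8.0, 10.0, 13.0  # illustrative bin contents
delta_up, delta_down = up / nom, down / nom


def branch(alpha):
    # The original scalar conditional from the non-linear interpolation.
    if alpha > 0:
        return nom * delta_up**alpha
    return nom * delta_down ** (-alpha)


alphas = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
looped = np.array([branch(alpha) for alpha in alphas])

# Select the base by the sign of alpha, then raise it to |alpha|.
bases = np.where(alphas > 0, delta_up, delta_down)
vectorized = nom * np.power(bases, np.abs(alphas))

print(np.allclose(looped, vectorized))
```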

[33]:
def interpolation_nonlinear(histogramssets, alphasets):
    all_results = []
    for histoset, alphaset in zip(histogramssets, alphasets):
        all_results.append([])
        set_result = all_results[-1]
        for histo in histoset:
            set_result.append([])
            histo_result = set_result[-1]
            for alpha in alphaset:
                alpha_result = []
                for down, nom, up in zip(histo[0], histo[1], histo[2]):
                    delta_up = up / nom
                    delta_down = down / nom
                    if alpha > 0:
                        delta = delta_up**alpha
                    else:
                        delta = delta_down ** (-alpha)
                    v = nom * delta
                    alpha_result.append(v)
                histo_result.append(alpha_result)
    return all_results


def new_interpolation_nonlinear_step0(histogramssets, alphasets):
    all_results = []
    for histoset, alphaset in zip(histogramssets, alphasets):
        all_results.append([])
        set_result = all_results[-1]
        for histo in histoset:
            set_result.append([])
            histo_result = set_result[-1]
            for alpha in alphaset:
                alpha_result = []
                for down, nom, up in zip(histo[0], histo[1], histo[2]):
                    delta_up = up / nom
                    delta_down = down / nom
                    delta = np.where(
                        alpha > 0,
                        np.power(delta_up, alpha),
                        np.power(delta_down, np.abs(alpha)),
                    )
                    v = nom * delta
                    alpha_result.append(v)
                histo_result.append(alpha_result)
    return all_results


def new_interpolation_nonlinear_step1(histogramssets, alphasets):
    all_results = []
    for histoset, alphaset in zip(histogramssets, alphasets):
        all_results.append([])
        set_result = all_results[-1]
        for histo in histoset:
            set_result.append([])
            histo_result = set_result[-1]
            for alpha in alphaset:
                deltas_up = np.divide(histo[2], histo[1])
                deltas_down = np.divide(histo[0], histo[1])
                bases = np.where(alpha > 0, deltas_up, deltas_down)
                exponents = np.abs(alpha)
                calc_deltas = np.power(bases, exponents)
                v = histo[1] * calc_deltas
                # v already holds every bin for this alpha, so append it directly
                # (wrapping it in another list would add a spurious extra axis)
                histo_result.append(v)
    return all_results


def new_interpolation_nonlinear_step2(histogramssets, alphasets):
    all_results = []

    allset_all_histo_deltas_up = np.divide(
        histogramssets[:, :, 2], histogramssets[:, :, 1]
    )
    allset_all_histo_deltas_dn = np.divide(
        histogramssets[:, :, 0], histogramssets[:, :, 1]
    )

    for nset, (histoset, alphaset) in enumerate(zip(histogramssets, alphasets)):
        set_result = []

        all_histo_deltas_up = allset_all_histo_deltas_up[nset]
        all_histo_deltas_dn = allset_all_histo_deltas_dn[nset]

        for nh, histo in enumerate(histoset):
            alpha_deltas = []
            for alpha in alphaset:
                alpha_result = []
                deltas_up = all_histo_deltas_up[nh]
                deltas_down = all_histo_deltas_dn[nh]
                bases = np.where(alpha > 0, deltas_up, deltas_down)
                exponents = np.abs(alpha)
                calc_deltas = np.power(bases, exponents)
                alpha_deltas.append(calc_deltas)
            set_result.append([histo[1] * d for d in alpha_deltas])
        all_results.append(set_result)
    return all_results


def new_interpolation_nonlinear_step3(histogramssets, alphasets):
    all_results = []

    allset_all_histo_deltas_up = np.divide(
        histogramssets[:, :, 2], histogramssets[:, :, 1]
    )
    allset_all_histo_deltas_dn = np.divide(
        histogramssets[:, :, 0], histogramssets[:, :, 1]
    )

    for nset, (histoset, alphaset) in enumerate(zip(histogramssets, alphasets)):
        set_result = []

        all_histo_deltas_up = allset_all_histo_deltas_up[nset]
        all_histo_deltas_dn = allset_all_histo_deltas_dn[nset]

        for nh, histo in enumerate(histoset):
            # bases and exponents need to have an outer product, to essentially tile or repeat over rows/cols
            bases_up = np.einsum(
                'a,b->ab', np.ones(alphaset.shape), all_histo_deltas_up[nh]
            )
            bases_dn = np.einsum(
                'a,b->ab', np.ones(alphaset.shape), all_histo_deltas_dn[nh]
            )
            exponents = np.einsum(
                'a,b->ab', np.abs(alphaset), np.ones(all_histo_deltas_up[nh].shape)
            )

            masks = np.einsum(
                'a,b->ab', alphaset > 0, np.ones(all_histo_deltas_dn[nh].shape)
            )
            bases = np.where(masks, bases_up, bases_dn)
            alpha_deltas = np.power(bases, exponents)
            set_result.append([histo[1] * d for d in alpha_deltas])

        all_results.append(set_result)
    return all_results


def new_interpolation_nonlinear_step4(histogramssets, alphasets):
    all_results = []

    allset_all_histo_nom = histogramssets[:, :, 1]
    allset_all_histo_deltas_up = np.divide(
        histogramssets[:, :, 2], allset_all_histo_nom
    )
    allset_all_histo_deltas_dn = np.divide(
        histogramssets[:, :, 0], allset_all_histo_nom
    )

    bases_up = np.einsum(
        'sa,shb->shab', np.ones(alphasets.shape), allset_all_histo_deltas_up
    )
    bases_dn = np.einsum(
        'sa,shb->shab', np.ones(alphasets.shape), allset_all_histo_deltas_dn
    )
    exponents = np.einsum(
        'sa,shb->shab', np.abs(alphasets), np.ones(allset_all_histo_deltas_up.shape)
    )

    masks = np.einsum(
        'sa,shb->shab', alphasets > 0, np.ones(allset_all_histo_deltas_up.shape)
    )
    bases = np.where(masks, bases_up, bases_dn)

    allsets_all_histos_deltas = np.power(bases, exponents)

    all_results = []
    for nset, histoset in enumerate(histogramssets):
        all_histos_deltas = allsets_all_histos_deltas[nset]
        set_result = []
        for nh, histo in enumerate(histoset):
            set_result.append([histoset[nh, 1] * d for d in all_histos_deltas[nh]])
        all_results.append(set_result)
    return all_results


def new_interpolation_nonlinear_step5(histogramssets, alphasets):
    all_results = []

    allset_all_histo_nom = histogramssets[:, :, 1]
    allset_all_histo_deltas_up = np.divide(
        histogramssets[:, :, 2], allset_all_histo_nom
    )
    allset_all_histo_deltas_dn = np.divide(
        histogramssets[:, :, 0], allset_all_histo_nom
    )

    bases_up = np.einsum(
        'sa,shb->shab', np.ones(alphasets.shape), allset_all_histo_deltas_up
    )
    bases_dn = np.einsum(
        'sa,shb->shab', np.ones(alphasets.shape), allset_all_histo_deltas_dn
    )
    exponents = np.einsum(
        'sa,shb->shab', np.abs(alphasets), np.ones(allset_all_histo_deltas_up.shape)
    )

    masks = np.einsum(
        'sa,shb->shab', alphasets > 0, np.ones(allset_all_histo_deltas_up.shape)
    )
    bases = np.where(masks, bases_up, bases_dn)

    allsets_all_histos_deltas = np.power(bases, exponents)

    all_results = []
    for nset, (_, alphaset) in enumerate(zip(histogramssets, alphasets)):
        all_histos_deltas = allsets_all_histos_deltas[nset]
        noms = allset_all_histo_nom[nset]
        all_histos_noms_repeated = np.einsum('a,hn->han', np.ones_like(alphaset), noms)
        set_result = all_histos_deltas * all_histos_noms_repeated
        all_results.append(set_result)
    return all_results


def new_interpolation_nonlinear_step6(histogramssets, alphasets):
    all_results = []

    allset_all_histo_nom = histogramssets[:, :, 1]
    allset_all_histo_deltas_up = np.divide(
        histogramssets[:, :, 2], allset_all_histo_nom
    )
    allset_all_histo_deltas_dn = np.divide(
        histogramssets[:, :, 0], allset_all_histo_nom
    )

    bases_up = np.einsum(
        'sa,shb->shab', np.ones(alphasets.shape), allset_all_histo_deltas_up
    )
    bases_dn = np.einsum(
        'sa,shb->shab', np.ones(alphasets.shape), allset_all_histo_deltas_dn
    )
    exponents = np.einsum(
        'sa,shb->shab', np.abs(alphasets), np.ones(allset_all_histo_deltas_up.shape)
    )

    masks = np.einsum(
        'sa,shb->shab', alphasets > 0, np.ones(allset_all_histo_deltas_up.shape)
    )
    bases = np.where(masks, bases_up, bases_dn)

    allsets_all_histos_deltas = np.power(bases, exponents)
    allsets_allhistos_noms_repeated = np.einsum(
        'sa,shb->shab', np.ones(alphasets.shape), allset_all_histo_nom
    )
    set_results = allsets_all_histos_deltas * allsets_allhistos_noms_repeated
    return set_results
[34]:
result, (h, a) = compare_fns(interpolation_nonlinear, new_interpolation_nonlinear_step0)
print(result)
True
[35]:
%%timeit
interpolation_nonlinear(h, a)
149 ms ± 9.45 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[36]:
%%timeit
new_interpolation_nonlinear_step0(h, a)
527 ms ± 29.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[37]:
result, (h, a) = compare_fns(interpolation_nonlinear, new_interpolation_nonlinear_step1)
print(result)
True
[38]:
%%timeit
interpolation_nonlinear(h, a)
150 ms ± 5.21 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[39]:
%%timeit
new_interpolation_nonlinear_step1(h, a)
456 ms ± 17.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[40]:
result, (h, a) = compare_fns(interpolation_nonlinear, new_interpolation_nonlinear_step2)
print(result)
True
[41]:
%%timeit
interpolation_nonlinear(h, a)
154 ms ± 4.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[42]:
%%timeit
new_interpolation_nonlinear_step2(h, a)
412 ms ± 31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[43]:
result, (h, a) = compare_fns(interpolation_nonlinear, new_interpolation_nonlinear_step3)
print(result)
True
[44]:
%%timeit
interpolation_nonlinear(h, a)
145 ms ± 5.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[45]:
%%timeit
new_interpolation_nonlinear_step3(h, a)
1.28 s ± 74.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[46]:
result, (h, a) = compare_fns(interpolation_nonlinear, new_interpolation_nonlinear_step4)
print(result)
True
[47]:
%%timeit
interpolation_nonlinear(h, a)
147 ms ± 8.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[48]:
%%timeit
new_interpolation_nonlinear_step4(h, a)
120 ms ± 3.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[49]:
result, (h, a) = compare_fns(interpolation_nonlinear, new_interpolation_nonlinear_step5)
print(result)
True
[50]:
%%timeit
interpolation_nonlinear(h, a)
151 ms ± 5.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[51]:
%%timeit
new_interpolation_nonlinear_step5(h, a)
2.65 ms ± 57.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[52]:
result, (h, a) = compare_fns(interpolation_nonlinear, new_interpolation_nonlinear_step6)
print(result)
True
[53]:
%%timeit
interpolation_nonlinear(h, a)
156 ms ± 3.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[54]:
%%timeit
new_interpolation_nonlinear_step6(h, a)
1.49 ms ± 16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)