# Source code for axsdb.math

"""
Fast interpolation with Numba.

This module provides high-performance interpolation functions implemented
using Numba's ``guvectorize`` decorator. These functions are designed to replace
xarray's interpolation for specific use cases where performance is critical.
"""

from __future__ import annotations

from typing import Literal

import numpy as np
from numba import guvectorize


# Integer codes for the out-of-bounds policy. They are passed to the
# interpolation gufunc as its ``bounds_mode`` scalar argument and compared
# against inside the compiled kernel.
_BOUNDS_FILL, _BOUNDS_CLAMP, _BOUNDS_RAISE = 0, 1, 2


def _make_interp1d_gufunc():  # pragma: no cover
    """
    Create the Numba gufunc for 1D linear interpolation.

    Returns a gufunc with signature ``(n),(n),(m),(),(),()->(m)`` that performs
    linear interpolation. The core arguments are, in order: ``x``, ``y``,
    ``xnew``, ``bounds_mode`` (one of the ``_BOUNDS_*`` codes), ``fill_lower``
    and ``fill_upper``; the output has the core shape of ``xnew``.

    The factory is invoked once at module load time, so the gufunc object (and
    its compilation, or its load from the on-disk cache since ``cache=True``)
    is built only once per process.
    """

    @guvectorize(
        [
            "void(float32[:], float32[:], float32[:], int64, float32, float32, float32[:])",
            "void(float64[:], float64[:], float64[:], int64, float64, float64, float64[:])",
        ],
        "(n),(n),(m),(),(),()->(m)",
        nopython=True,
        cache=True,
    )
    def _interp1d_gufunc_impl(x, y, xnew, bounds_mode, fill_lower, fill_upper, out):
        """
        Low-level gufunc for 1D linear interpolation.

        Parameters
        ----------
        x : ndarray
            X coordinates of the data points (must be sorted in ascending
            order and non-empty: ``x[0]`` is read unconditionally and is not
            bounds-checked here).
            Shape (n,).

        y : ndarray
            Y coordinates of the data points.
            Shape (n,).

        xnew : ndarray
            X coordinates at which to evaluate the interpolation.
            Shape (m,).

        bounds_mode : int
            Bounds handling mode:

            * 0 (fill): use ``fill_lower``/``fill_upper`` for out-of-bounds
              points;
            * 1 (clamp): use nearest boundary value;
            * 2 (raise): write NaN for out-of-bounds points. The kernel itself
              never raises, so callers must validate bounds beforehand. Any
              value other than 0 or 1 takes this branch.

        fill_lower : float
            Fill value for points below ``x[0]`` (only used when bounds_mode=0).

        fill_upper : float
            Fill value for points above ``x[-1]`` (only used when bounds_mode=0).

        out : ndarray
            Output array for interpolated values, written in place.
            Shape (m,).

        Notes
        -----
        NaN query points produce NaN in every mode. A query point exactly equal
        to a grid node returns the corresponding ``y`` value with no arithmetic.
        """
        n = len(x)
        m = len(xnew)

        # x is sorted ascending, so its first and last entries bound the data
        # range. Requires n >= 1 (not checked here).
        x_min = x[0]
        x_max = x[n - 1]

        for i in range(m):
            xi = xnew[i]

            # Handle NaN in query point
            if np.isnan(xi):
                out[i] = np.nan
                continue

            # Handle out-of-bounds: below minimum
            if xi < x_min:
                if bounds_mode == _BOUNDS_FILL:
                    out[i] = fill_lower
                elif bounds_mode == _BOUNDS_CLAMP:
                    out[i] = y[0]
                else:  # bounds_mode == _BOUNDS_RAISE
                    out[i] = np.nan  # NaN marker; interp1d checks bounds before calling
                continue

            # Handle out-of-bounds: above maximum
            if xi > x_max:
                if bounds_mode == _BOUNDS_FILL:
                    out[i] = fill_upper
                elif bounds_mode == _BOUNDS_CLAMP:
                    out[i] = y[n - 1]
                else:  # bounds_mode == _BOUNDS_RAISE
                    out[i] = np.nan  # NaN marker; interp1d checks bounds before calling
                continue

            # Binary search to find the interval [x[left], x[right]].
            # Invariant: x[left] <= xi <= x[right] (guaranteed by the range
            # checks above); the loop stops once the two indices are adjacent
            # (or equal, when n == 1).
            left = 0
            right = n - 1

            while right - left > 1:
                mid = (left + right) // 2
                if x[mid] <= xi:
                    left = mid
                else:
                    right = mid

            # Handle exact match at boundary to avoid numerical issues:
            # returning y directly sidesteps rounding in y0 + t * (y1 - y0)
            # (t == 1 need not reproduce y1 exactly). This is also the only
            # path reached when n == 1, where left == right.
            if xi == x[left]:
                out[i] = y[left]
                continue
            if xi == x[right]:
                out[i] = y[right]
                continue

            # Linear interpolation: y0 + t * (y1 - y0) where t = (xi - x0) / (x1 - x0)
            # Here x0 < xi < x1 strictly, so the denominator is non-zero.
            x0 = x[left]
            x1 = x[right]
            y0 = y[left]
            y1 = y[right]

            t = (xi - x0) / (x1 - x0)
            out[i] = y0 + t * (y1 - y0)

    return _interp1d_gufunc_impl


def _make_lerp_gufunc():  # pragma: no cover
    """
    Build the Numba gufunc that evaluates a lerp from precomputed bins.

    The returned gufunc has layout ``(n),(m),(m)->(m)``: data values ``y``,
    left-bin indices and weights in, interpolated values out. No search is
    performed; the indices and weights are expected to come from
    :func:`lerp_indices`.
    """

    @guvectorize(
        [
            "void(float32[:], float32[:], float32[:], float32[:])",
            "void(float64[:], float64[:], float64[:], float64[:])",
        ],
        "(n),(m),(m)->(m)",
        nopython=True,
        cache=True,
    )
    def _lerp_gufunc_impl(y, indices, weights, out):
        # Indices are not validated: each must lie in [0, n - 2] so that both
        # y[j] and y[j + 1] are valid reads.
        for k in range(indices.shape[0]):
            j = int(indices[k])
            lo = y[j]
            out[k] = lo + weights[k] * (y[j + 1] - lo)

    return _lerp_gufunc_impl


# Build both compiled gufuncs once, at import time, so every call to the
# public wrappers reuses the same objects.
_interp1d_gufunc, _lerp_gufunc = _make_interp1d_gufunc(), _make_lerp_gufunc()


def interp1d(
    x: np.ndarray,
    y: np.ndarray,
    xnew: np.ndarray,
    bounds: Literal["fill", "clamp", "raise"] = "fill",
    fill_value: float | tuple[float, float] = np.nan,
) -> np.ndarray:
    """
    Fast 1D linear interpolation.

    This function provides high-performance linear interpolation that
    broadcasts over leading dimensions. It powers a drop-in replacement for
    cases where xarray's interpolation is too slow.

    Parameters
    ----------
    x : array-like
        X coordinates of the data points. Must be sorted in ascending order
        along the last axis and contain at least one point. Results are
        undefined for unsorted x.
        Shape (..., n).

    y : array-like
        Y coordinates of the data points.
        Shape (..., n).

    xnew : array-like
        X coordinates at which to evaluate the interpolation.
        Shape (..., m).

    bounds : {"fill", "clamp", "raise"}, default: "fill"
        How to handle out-of-bounds query points:

        * ``"fill"``: use ``fill_value`` for points outside the data range.
        * ``"clamp"``: use the nearest boundary value (``y[0]`` or ``y[-1]``).
        * ``"raise"``: raise a ValueError if any query point lies outside the
          range ``[x[..., 0], x[..., -1]]`` of the curve it is evaluated
          against.

    fill_value : float or tuple of (float, float), default: np.nan
        Value(s) to use for out-of-bounds points when ``bounds="fill"``:

        * if a single float, use for both lower and upper bounds;
        * if a 2-tuple (or 2-element list), use (``fill_lower``,
          ``fill_upper``).

    Returns
    -------
    ndarray
        Interpolated values at the query points. The output shape is
        determined by numpy broadcasting rules applied to x, y, and xnew.
        Shape (..., m).

    Raises
    ------
    ValueError
        * If ``bounds="raise"`` and any query point is outside the data range
          of its curve.
        * If ``bounds`` is not one of "fill", "clamp", or "raise".
        * If ``fill_value`` is a tuple or list with length != 2.
        * If ``x`` is 0-d or empty along its last axis.

    Notes
    -----
    * The implementation uses a Numba gufunc with signature
      ``(n),(n),(m),(),(),()->(m)`` (``x``, ``y``, ``xnew``, bounds mode,
      lower fill, upper fill) for the core interpolation, enabling efficient
      broadcasting over arbitrary leading dimensions.
    * The function assumes ``x`` is sorted in ascending order along the last
      axis. Results are undefined if this assumption is violated.
    * Non-floating inputs are converted to float64, then all inputs are
      promoted to their common dtype.
    * NaN values in ``xnew`` are passed through (output will be NaN).

    Examples
    --------
    Basic interpolation:

    >>> x = np.array([0.0, 1.0, 2.0, 3.0])
    >>> y = np.array([0.0, 1.0, 4.0, 9.0])
    >>> xnew = np.array([0.5, 1.5, 2.5])
    >>> interp1d(x, y, xnew)
    array([0.5, 2.5, 6.5])

    With fill values for out-of-bounds:

    >>> xnew = np.array([-1.0, 1.5, 5.0])
    >>> interp1d(x, y, xnew, bounds="fill", fill_value=(-999.0, 999.0))
    array([-999. ,    2.5,  999. ])

    Clamping to boundary values:

    >>> interp1d(x, y, xnew, bounds="clamp")
    array([0. , 2.5, 9. ])

    Broadcasting over multiple curves:

    >>> x = np.array([0.0, 1.0, 2.0])
    >>> y = np.array([[0.0, 1.0, 2.0],  # Linear
    ...               [0.0, 1.0, 4.0]])  # Quadratic
    >>> xnew = np.array([0.5, 1.5])
    >>> interp1d(x, y, xnew)
    array([[0.5, 1.5],
           [0.5, 2.5]])
    """
    # Convert inputs to numpy arrays
    x = np.asarray(x)
    y = np.asarray(y)
    xnew = np.asarray(xnew)

    # Validate bounds mode (its integer code is looked up just before the
    # kernel call).
    if bounds not in ("fill", "clamp", "raise"):
        raise ValueError(
            f"Invalid bounds mode: {bounds!r}. Must be one of 'fill', 'clamp', 'raise'."
        )

    # Parse fill_value. Lists are accepted alongside tuples; otherwise a list
    # would be turned into an array and broadcast as an extra loop dimension.
    if isinstance(fill_value, (tuple, list)):
        if len(fill_value) != 2:
            raise ValueError(
                f"fill_value tuple must have exactly 2 elements, got {len(fill_value)}"
            )
        fill_lower, fill_upper = fill_value
    else:
        fill_lower = fill_upper = fill_value

    # The kernel reads x[0] and x[n - 1] unconditionally; an empty grid would
    # be an out-of-bounds memory read in compiled code.
    if x.ndim == 0 or x.shape[-1] == 0:
        raise ValueError(
            f"x must contain at least one data point along its last axis, got shape {x.shape}"
        )

    # Ensure float dtype (convert integers to float64)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(np.float64)
    if not np.issubdtype(y.dtype, np.floating):
        y = y.astype(np.float64)
    if not np.issubdtype(xnew.dtype, np.floating):
        xnew = xnew.astype(np.float64)

    # Promote to common dtype (copy=False is a no-op when already matching)
    common_dtype = np.result_type(x, y, xnew)
    x = x.astype(common_dtype, copy=False)
    y = y.astype(common_dtype, copy=False)
    xnew = xnew.astype(common_dtype, copy=False)

    # Convert fill values to the common dtype
    fill_lower = common_dtype.type(fill_lower)
    fill_upper = common_dtype.type(fill_upper)

    if bounds == "raise":
        # Check each query against the range of the curve it will actually be
        # evaluated on: x[..., 0] / x[..., -1], the same bounds the kernel
        # uses. Comparing against the global min/max of x would let a query
        # that is out of range for its own curve slip through (the kernel then
        # silently writes NaN). The (..., 1) slices broadcast against xnew
        # exactly as the gufunc pairs x rows with xnew rows. NaN queries
        # compare False and are therefore passed through.
        x_lo = x[..., :1]
        x_hi = x[..., -1:]
        below = xnew < x_lo
        above = xnew > x_hi
        any_below = bool(below.any())
        any_above = bool(above.any())
        if any_below or any_above:
            msg_parts = ["Query points out of bounds."]
            if any_below:
                delta_low = np.max((x_lo - xnew)[below])
                msg_parts.append(f"Below lower bound by up to {delta_low:.6g}")
            if any_above:
                delta_high = np.max((xnew - x_hi)[above])
                msg_parts.append(f"Above upper bound by up to {delta_high:.6g}")
            raise ValueError(" ".join(msg_parts))

    # Call the gufunc
    bounds_mode = {
        "fill": _BOUNDS_FILL,
        "clamp": _BOUNDS_CLAMP,
        "raise": _BOUNDS_RAISE,
    }[bounds]
    return _interp1d_gufunc(
        x, y, xnew, np.int64(bounds_mode), fill_lower, fill_upper
    )
def lerp_indices(
    x: np.ndarray,
    xnew: np.ndarray,
    bounds: Literal["fill", "clamp", "raise"] = "fill",
) -> tuple[np.ndarray, np.ndarray]:
    """
    Precompute left-indices and interpolation weights for linear interpolation.

    When the same query points ``xnew`` will be applied to many different
    ``y`` arrays sharing the same ``x`` grid, it is far cheaper to run the
    binary search once here and then call :func:`lerp` for each ``y``. That
    function skips the search entirely and executes only the
    ``y[left] + t*(y[left+1] - y[left])`` step.

    Parameters
    ----------
    x : ndarray
        Sorted coordinate grid (1-D) with at least 2 points.
        Shape (n,).

    xnew : ndarray
        Query points (1-D).
        Shape (m,).

    bounds : {"fill", "clamp", "raise"}, default: "fill"
        Out-of-bounds handling, same semantics as :func:`interp1d`.

        * ``"fill"``: out-of-bounds queries get the nearest valid bin index
          (0 below the grid, ``n - 2`` above it) with weight NaN, so that
          :func:`lerp` produces NaN there. The caller can replace those NaNs
          after the fact if a different fill value is needed.
        * ``"clamp"``: out-of-bounds queries get the nearest valid bin index
          with weight 0 (below, reproducing ``y[0]``) or 1 (above,
          reproducing ``y[-1]`` up to rounding).
        * ``"raise"``: raises immediately if any query is out of bounds.

        In every mode, NaN queries yield a NaN weight.

    Returns
    -------
    indices : ndarray
        Left-bin indices as floats (required by the gufunc signature), always
        in ``[0, n - 2]``.
        Shape (m,), dtype float64.

    weights : ndarray
        Fractional position within each bin:
        ``t = (xnew - x[i]) / (x[i+1] - x[i])``.
        Shape (m,), dtype float64.

    Raises
    ------
    ValueError
        * If ``bounds="raise"`` and any (non-NaN) query point is outside
          ``[x[0], x[-1]]``.
        * If ``bounds`` is not one of "fill", "clamp", or "raise".
        * If ``x`` has fewer than 2 points.
    """
    if bounds not in ("fill", "clamp", "raise"):
        raise ValueError(
            f"Invalid bounds mode: {bounds!r}. Must be one of 'fill', 'clamp', 'raise'."
        )

    x = np.asarray(x, dtype=np.float64)
    xnew = np.asarray(xnew, dtype=np.float64)
    n = len(x)

    # A bin needs two grid points; with fewer, the clip range [0, n - 2] below
    # is empty and lerp() would read y out of bounds.
    if n < 2:
        raise ValueError(f"x must contain at least 2 points, got {n}")

    if bounds == "raise":
        # Element-wise comparisons: NaN queries compare False and pass through,
        # but cannot mask genuine violations (xnew.min()/max() would return
        # NaN as soon as one query is NaN). Also safe for an empty xnew.
        below = xnew < x[0]
        above = xnew > x[-1]
        any_below = bool(below.any())
        any_above = bool(above.any())
        if any_below or any_above:
            # Build informative error message
            msg_parts = ["Query points out of bounds."]
            if any_below:
                delta_low = x[0] - np.min(xnew[below])
                msg_parts.append(f"Below lower bound by up to {delta_low:.6g}")
            if any_above:
                delta_high = np.max(xnew[above]) - x[-1]
                msg_parts.append(f"Above upper bound by up to {delta_high:.6g}")
            raise ValueError(" ".join(msg_parts))

    # searchsorted gives insertion point; left-bin index = insertion - 1
    raw = np.searchsorted(x, xnew, side="right") - 1  # shape (m,)

    # Clamp indices to valid bin range [0, n-2]
    indices = np.clip(raw, 0, n - 2)

    # Compute weights
    x0 = x[indices]
    x1 = x[indices + 1]
    weights = (xnew - x0) / (x1 - x0)

    if bounds == "clamp":
        # For clamping, boundary points need special handling:
        # - Points below x[0]: index=0, weight=0 -> y[0] + 0*(y[1]-y[0]) = y[0]
        # - Points above x[-1]: index=n-2, weight=1 -> y[n-2] + 1*(y[n-1]-y[n-2])
        # <= and >= (not < and >) pin exact boundary matches too, so they
        # never carry a tiny non-zero rounding residue in the weight.
        weights = np.where(xnew <= x[0], 0.0, weights)
        weights = np.where(xnew >= x[-1], 1.0, weights)
    elif bounds == "fill":
        # Mark out-of-bounds with NaN weight so lerp produces NaN
        oob = (xnew < x[0]) | (xnew > x[-1])
        weights = np.where(oob, np.nan, weights)

    return indices.astype(np.float64), weights
def lerp(y: np.ndarray, indices: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """
    Evaluate a linear interpolation from precomputed bins.

    Pairs with :func:`lerp_indices`. Once the bin search has been done for a
    fixed ``x`` grid and fixed query points, this function can be applied to
    any number of ``y`` arrays on that grid. It computes only
    ``y[i] + t * (y[i+1] - y[i])`` for each query.

    Parameters
    ----------
    y : ndarray
        Data values whose last axis matches the ``x`` grid given to
        :func:`lerp_indices`. Leading dimensions are broadcast by the
        underlying gufunc.
        Shape (..., n).

    indices : ndarray
        Left-bin indices as returned by :func:`lerp_indices`.
        **IMPORTANT**: every index must lie in ``[0, n-2]`` with
        ``n = y.shape[-1]``. :func:`lerp_indices` guarantees this; it is
        deliberately not re-checked here, for speed.
        Shape (m,).

    weights : ndarray
        Bin weights as returned by :func:`lerp_indices`.
        Shape (m,).

    Returns
    -------
    ndarray
        Interpolated values (float64). NaN weights, as produced by
        ``bounds="fill"``, yield NaN.
        Shape (..., m).

    Notes
    -----
    All three inputs are converted to float64 before the gufunc call. Because
    ``indices`` is not bounds-checked, out-of-range values must be ruled out
    by the caller; going through :func:`lerp_indices` does that.
    """
    y_f64, idx_f64, w_f64 = (
        np.asarray(arr, dtype=np.float64) for arr in (y, indices, weights)
    )
    return _lerp_gufunc(y_f64, idx_f64, w_f64)