Merge pull request #231 from jorenham/bugfix/cache
Fix several `l_weights` caching issues
jorenham authored May 22, 2024
2 parents 0b0e611 + bc335ac commit b13fe23
Showing 3 changed files with 194 additions and 83 deletions.
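In short: with `cache=True`, `l_weights` now stores one read-only `longdouble` weight matrix of at least four orders per `(n, s, t)` key, and serves smaller orders as slices of it. A minimal sketch of the intended post-fix behavior, mirroring the new `test_cached` (exact identity of the returned arrays requires `dtype=np.longdouble`):

```python
import numpy as np
from lmo import l_weights

n, trim = 100, (1, 1)

# the first call computes and caches a read-only matrix for r_max=4
w4 = l_weights(4, n, trim, cache=True, dtype=np.longdouble)
assert not w4.flags.writeable

# a repeat call with the same order returns the cached array itself
assert l_weights(4, n, trim, cache=True, dtype=np.longdouble) is w4

# a smaller order is served as a (read-only) slice of the same matrix
w3 = l_weights(3, n, trim, cache=True, dtype=np.longdouble)
assert np.shares_memory(w3, w4)
```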
116 changes: 61 additions & 55 deletions lmo/_lm.py
@@ -17,6 +17,7 @@
round0,
)
from .linalg import ir_pascal, sandwich, sh_legendre, trim_matrix
from .typing import np as lnpt
from .typing.compat import TypeVar


@@ -30,7 +31,6 @@
AnyOrderND,
AnyTrim,
LMomentOptions,
np as lnpt,
)
from .typing.compat import Unpack

@@ -66,13 +66,13 @@
_Vectorized: TypeAlias = _T_float | npt.NDArray[_T_float]
_Floating: TypeAlias = np.floating[Any]

_L_WEIGHTS_CACHE: Final[
dict[
# (n, s, t)
tuple[int, int, int] | tuple[int, float, float],
lnpt.Array[tuple[int, int], _Floating],
]
] = {}
# (n, s, t)
_CacheKey: TypeAlias = tuple[int, int, int] | tuple[int, float, float]
# `r: _T_order >= 4`
_CacheArray: TypeAlias = lnpt.Array[tuple[_T_order, _T_size], np.longdouble]
_Cache: TypeAlias = dict[_CacheKey, _CacheArray[Any, Any]]

_CACHE: Final[_Cache] = {}
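For orientation, a hedged illustration of what a single cache entry looks like under these aliases (the key packs the sample size and trim orders; the values here are made up):

```python
import numpy as np

key = (100, 1, 1)  # a _CacheKey: (n, s, t)
# a _CacheArray: one row per L-moment order (at least 4), in extended precision
entry = np.empty((4, 100), dtype=np.longdouble)
entry.setflags(write=False)  # cached entries are stored read-only
cache: dict[tuple[int, int, int], np.ndarray] = {key: entry}  # a _Cache
```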


def _l_weights_pwm(
@@ -91,7 +91,15 @@ def _l_weights_pwm(
sh_legendre(r0, dtype=np.int64 if r0 < 29 else dtype),
pwm_beta.weights(r0, n, dtype=dtype),
)
return np.matmul(trim_matrix(r, trim, dtype=dtype), w0) if s or t else w0
wr = np.matmul(trim_matrix(r, trim, dtype=dtype), w0) if s or t else w0

# ensure that the trimmed ends are 0
if s:
wr[:, :s] = 0
if t:
wr[:, -t:] = 0

return wr
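A small sketch of the invariant this hunk enforces (matching the new `test_trim`): for trim `(s, t)`, the first `s` and last `t` columns of the weight matrix are exactly zero, so trimmed observations never contribute:

```python
import numpy as np
from lmo import l_weights

s, t = 2, 3
w = l_weights(4, 50, (s, t))
assert np.all(w[:, :s] == 0)   # left-trimmed samples get zero weight
assert np.all(w[:, -t:] == 0)  # right-trimmed samples get zero weight
```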


def _l_weights_ostat(
@@ -125,7 +133,7 @@


def l_weights(
r: _T_order,
r_max: _T_order,
n: _T_size,
/,
trim: AnyTrim = 0,
@@ -170,8 +178,15 @@ def l_weights(
observation vector(s) of size $n$), into (an unbiased estimate of) the
*generalized trimmed L-moments*, with orders $\le r$.
Args:
r_max: The number of L-moment orders.
n: The number of samples.
trim: A scalar or 2-tuple with the trim orders. Defaults to 0.
dtype: The datatype of the returned weight matrix.
cache: Whether to cache the weights, defaults to `False`.
Returns:
P_r: 2-D array of shape `(r, n)`.
P_r: 2-D array of shape `(r_max, n)`, readonly if `cache=True`.
Examples:
>>> import lmo
@@ -186,56 +201,47 @@
- [J.R.M. Hosking (2007) - Some theory and practical uses of trimmed
L-moments](https://doi.org/10.1016/j.jspi.2006.12.002)
"""
if r == 0:
return np.empty((r, n), dtype=dtype)

match clean_trim(trim):
case s, t if s < 0 or t < 0:
msg = f'trim orders must be >=0, got {trim}'
raise ValueError(msg)
case s, t:
pass
case _: # type: ignore [reportUneccessaryComparison]
msg = (
f'trim must be a tuple with two non-negative ints or floats, '
f'got {trim!r}'
)
raise TypeError(msg)
if r_max < 0:
msg = f'r must be non-negative, got {r_max}'
raise ValueError(msg)

if r_max == 0:
return np.empty((0, n), dtype=dtype)

s, t = clean_trim(trim)

if (n_min := r_max + s + t) > n:
msg = f'expected n >= r + s + t, got {n} < {n_min}'
raise ValueError(msg)

# manual cache lookup, only if cache=False (for testability)
# e.g. `functools.cache` would be inefficient for e.g. r=3 with cached r=4
cache_key = n, s, t
if (
cache
and cache_key in _L_WEIGHTS_CACHE
and (w := _L_WEIGHTS_CACHE[cache_key]).shape[0] <= r
):
if w.shape[0] < r:
w = w[:r]

# ignore if r is larger that what's cached
if w.shape[0] == r:
assert w.shape == (r, n)
return w.astype(dtype)

if r + s + t <= 24 and isinstance(s, int) and isinstance(t, int):
w = _l_weights_pwm(r, n, (s, t), dtype=dtype or np.float64)

# ensure that the trimmed ends are 0
if s:
w[:, :s] = 0
if t:
w[:, -t:] = 0
key = n, s, t
if (w := _CACHE.get(key)) is not None and w.shape[0] >= r_max:
pass
else:
w = _l_weights_ostat(r, n, (s, t), dtype=dtype or np.float64)
# when caching, use at least 4 orders, to avoid cache misses
_r_max = 4 if cache and r_max < 4 else r_max

# use the highest available precision when caching (96 or 128 bits,
# depending on the platform)
_dtype = np.longdouble if cache else dtype

if r_max + s + t <= 24 and isinstance(s, int) and isinstance(t, int):
w = _l_weights_pwm(_r_max, n, (s, t), dtype=_dtype)
else:
w = _l_weights_ostat(_r_max, n, (s, t), dtype=_dtype)

if cache:
w.setflags(write=False)
# be wary of a potential race condition
if key not in _CACHE or w.shape[0] >= _CACHE[key].shape[0]:
_CACHE[key] = w

if cache:
# the pyright error here is due to the fact that the first type param
# of `np.ndarray` is invariant (which is incorrect), instead of
# being covariant
_L_WEIGHTS_CACHE[cache_key] = w # pyright: ignore[reportArgumentType]
if w.shape[0] > r_max:
w = w[:r_max]

return w
return w.astype(dtype, casting='same_kind', copy=False)
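A hedged note on this final cast: when the requested `dtype` differs from the cached `longdouble`, `astype(..., copy=False)` must still allocate, so the caller receives a fresh writable copy rather than the read-only cache master. For example (assuming the documented `np.float64` default for `dtype`):

```python
import numpy as np
from lmo import l_weights

w = l_weights(4, 50, (1, 1), cache=True)  # default dtype: np.float64
assert w.dtype == np.float64 and w.flags.writeable
w[0, 0] += 1.0  # safe: mutating the cast copy leaves the cache untouched
```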


@overload
25 changes: 25 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,25 @@
# ruff: noqa: SLF001
# pyright: reportPrivateUsage=false
import contextlib
from collections.abc import Iterator

import pytest

from lmo import _lm


@contextlib.contextmanager
def tmp_cache() -> Iterator[_lm._Cache]:
cache_tmp: _lm._Cache = {}
cache_old, _lm._CACHE = _lm._CACHE, cache_tmp
try:
yield cache_tmp
finally:
_lm._CACHE = cache_old


@pytest.fixture(name='tmp_cache')
def tmp_cache_fixture():
with tmp_cache() as cache:
assert not cache
yield cache
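A hypothetical usage sketch (for illustration only, not part of this commit): a test that takes the `tmp_cache` fixture sees `lmo._lm._CACHE` swapped for an empty dict, so it can assert on cache contents in isolation:

```python
# hypothetical test, for illustration only
def test_cache_isolation(tmp_cache):
    from lmo._lm import l_weights

    l_weights(4, 32, cache=True)
    assert len(tmp_cache) == 1  # the new entry landed in the temporary cache
```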
136 changes: 108 additions & 28 deletions tests/test_weights.py
@@ -1,78 +1,158 @@
import functools
from typing import Any

import numpy as np
import pytest
from hypothesis import (
given,
strategies as st,
)
from hypothesis.extra import numpy as hnp
from numpy.testing import (
assert_allclose as _assert_allclose,
assert_array_equal,
)

from lmo import l_weights
from .conftest import tmp_cache

from lmo._lm import l_weights

# matches np.allclose
assert_allclose = functools.partial(_assert_allclose, rtol=1e-5, atol=1e-8)


MAX_N = 1 << 10
MAX_R = 8
MAX_T = 4
MIN_N = MAX_R + MAX_T * 2 + 1
MAX_N = 1 << 8


st_n = st.integers(MAX_R + MAX_T * 2 + 1, MAX_N)
st_n = st.integers(MIN_N, MAX_N)
st_r = st.integers(1, MAX_R)

st_t_f = st.floats(0, MAX_T, exclude_min=True)
st_t_i = st.integers(1, MAX_T)
st_t_i0 = st.integers(0, MAX_T)
st_i_eq0 = st.just(0)
st_i_ge0 = st.integers(0, MAX_T)
st_i_gt0 = st.integers(1, MAX_T)

st_i2_eq0 = st.tuples(st.just(0), st.just(0))
st_i2_ge0 = st.tuples(st.integers(0, MAX_T), st.integers(0, MAX_T))
st_i2_gt0 = st.tuples(st.integers(1, MAX_T), st.integers(1, MAX_T))

st_trim_i = st.tuples(st_t_i, st_t_i)
st_trim_i0 = st.tuples(st_t_i0, st_t_i0)
st_i12_eq0 = st_i_eq0 | st_i2_eq0
st_i12_ge0 = st_i_ge0 | st_i2_ge0
st_i12_gt0 = st_i_gt0 | st_i2_gt0

st_floating = hnp.floating_dtypes()

@given(n=st_n, r=st_r, trim0=st.just((0, 0)))
def test_l_weights_alias(n, r, trim0):

@given(n=st_n, trim=st_i12_eq0)
def test_empty(n: int, trim: int | tuple[int, int]):
w = l_weights(0, n, trim)
assert w.shape == (0, n)


@given(n=st_n, r=st_r, trim=st_i12_eq0)
def test_untrimmed(n: int, r: int, trim: int | tuple[int, int]):
w_l = l_weights(r, n)
w_tl = l_weights(r, n, trim0)
w_tl = l_weights(r, n, trim)

assert np.array_equal(w_l, w_tl)
assert_array_equal(w_l, w_tl)


@given(n=st_n, r=st_r, trim=st_trim_i0)
def test_l_weights_basic(n, r, trim):
@given(n=st_n, r=st_r, trim=st_i12_ge0)
def test_default(n: int, r: int, trim: int | tuple[int, int]):
w = l_weights(r, n, trim)

assert w.shape == (r, n)
assert np.all(np.isfinite(n))
assert np.all(np.isfinite(w))
assert w.dtype.type is np.float64


# symmetries only apply for symmetric trimming, for obvious reasons
@given(n=st_n, t=st_t_i0)
def test_l_weights_symmetry(n, t):
@given(n=st_n, r=st_r, trim=st_i12_ge0, dtype=st_floating)
def test_dtype(
n: int,
r: int,
trim: int | tuple[int, int],
dtype: np.dtype[np.floating[Any]],
):
w = l_weights(r, n, trim, dtype=dtype)

assert np.all(np.isfinite(w))
assert w.dtype.type is dtype.type


@given(n=st_n, t=st_i_ge0)
def test_symmetry(n: int, t: int):
w = l_weights(MAX_R, n, (t, t))

w_evn_lhs, w_evn_rhs = w[::2], w[::2, ::-1]
assert np.allclose(w_evn_lhs, w_evn_rhs)
assert_allclose(w_evn_lhs, w_evn_rhs)

w_odd_lhs, w_odd_rhs = w[1::2], w[1::2, ::-1]
assert np.allclose(w_odd_lhs, -w_odd_rhs)
assert_allclose(w_odd_lhs, -w_odd_rhs)


def test_l_weights_symmetry_large_even_r():
w = l_weights(16, MAX_N * 2)

w_evn_lhs, w_evn_rhs = w[::2], w[::2, ::-1]
assert np.allclose(w_evn_lhs, w_evn_rhs)
assert_allclose(w_evn_lhs, w_evn_rhs)


@given(n=st_n, r=st_r, trim=st_trim_i)
def test_l_weights_trim(n, r, trim):
@given(n=st_n, r=st_r, trim=st_i2_gt0)
def test_trim(n: int, r: int, trim: tuple[int, int]):
w = l_weights(r, n, trim)

tl, tr = trim
assert tl > 0
assert tr > 0

assert np.allclose(w[:, :tl], 0)
assert np.allclose(w[:, n - tr :], 0)
assert_allclose(w[:, :tl], 0)
assert_allclose(w[:, n - tr :], 0)


@given(n=st_n, r=st.integers(2, MAX_R), trim=st_trim_i0)
def test_tl_weights_sum(n, r, trim):
@given(n=st_n, r=st.integers(2, MAX_R), trim=st_i12_ge0)
def test_sum(n: int, r: int, trim: int | tuple[int, int]):
w = l_weights(r, n, trim)
w_sum = w.sum(axis=-1)

assert np.allclose(w_sum, np.eye(r, 1).ravel())
assert_allclose(w_sum, np.eye(r, 1).ravel())


@given(n=st_n, r=st.integers(4, MAX_R), trim=st_i12_ge0)
def test_uncached(n: int, r: int, trim: int | tuple[int, int]):
with tmp_cache() as cache:
w0 = l_weights(r, n, trim, cache=False)
w1 = l_weights(r, n, trim, cache=False)

assert not cache
assert w0 is not w1
assert_array_equal(w0, w1)


@given(n=st_n, r=st.integers(4, MAX_R), trim=st_i12_ge0)
def test_cached(n: int, r: int, trim: int | tuple[int, int]):
cache_key = (n, *trim) if isinstance(trim, tuple) else (n, trim, trim)

with tmp_cache() as cache:
assert cache_key not in cache

w0 = l_weights(r, n, trim, cache=True, dtype=np.longdouble)
assert cache_key in cache
w0_cached = cache[cache_key]

# cached weights should be readonly
w0_orig = w0[0, 0]
with pytest.raises(
ValueError,
match='assignment destination is read-only',
):
w0[0, 0] = w0_orig + 1
assert w0[0, 0] == w0_orig

w1 = l_weights(r, n, trim, cache=True, dtype=np.longdouble)
w1_cached = cache[cache_key]
assert w0_cached is w1_cached

# this requires `r>=4`, `dtype=np.longdouble` and `r == r_cached`
assert w0 is w1
