Skip to content

Commit

Permalink
Merge pull request #348 from jorenham/l_moment-from-rv
Browse files Browse the repository at this point in the history
preparations for the decoupling of the L-moment methods and the scipy distributions
  • Loading branch information
jorenham authored Nov 19, 2024
2 parents 3ba3b39 + ab11075 commit 39520bb
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 308 deletions.
24 changes: 10 additions & 14 deletions lmo/_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from . import ostats, pwm_beta
from ._utils import (
clean_order,
clean_orders,
clean_trim,
ensure_axis_at,
l_stats_orders,
Expand All @@ -22,7 +21,6 @@
sort_maybe,
)
from .linalg import ir_pascal, sandwich, sh_legendre, trim_matrix
from .typing import AnyOrder, AnyOrderND

if sys.version_info >= (3, 13):
from typing import TypeVar
Expand All @@ -34,6 +32,7 @@

import lmo.typing as lmt
import lmo.typing.np as lnpt
from .typing import AnyOrder, AnyOrderND


__all__ = (
Expand Down Expand Up @@ -248,9 +247,9 @@ def l_moment(
a: lnpt.AnyArrayFloat,
r: AnyOrder,
/,
trim: lmt.AnyTrim = ...,
trim: lmt.AnyTrim = 0,
*,
axis: None = ...,
axis: None = None,
dtype: _DType[_SCT_f] = np.float64,
**kwds: Unpack[lmt.LMomentOptions],
) -> _SCT_f: ...
Expand All @@ -259,7 +258,7 @@ def l_moment(
a: lnpt.AnyMatrixFloat | lnpt.AnyTensorFloat,
r: AnyOrder | AnyOrderND,
/,
trim: lmt.AnyTrim = ...,
trim: lmt.AnyTrim = 0,
*,
axis: int,
dtype: _DType[_SCT_f] = np.float64,
Expand All @@ -270,7 +269,7 @@ def l_moment(
a: lnpt.AnyVectorFloat,
r: AnyOrder,
/,
trim: lmt.AnyTrim = ...,
trim: lmt.AnyTrim = 0,
*,
axis: int,
dtype: _DType[_SCT_f] = np.float64,
Expand All @@ -281,9 +280,9 @@ def l_moment(
a: lnpt.AnyArrayFloat,
r: AnyOrderND,
/,
trim: lmt.AnyTrim = ...,
trim: lmt.AnyTrim = 0,
*,
axis: int | None = ...,
axis: int | None = None,
dtype: _DType[_SCT_f] = np.float64,
**kwds: Unpack[lmt.LMomentOptions],
) -> onpt.Array[Any, _SCT_f]: ...
Expand Down Expand Up @@ -410,19 +409,16 @@ def l_moment(
x_k = ensure_axis_at(x_k, axis, -1)
n = x_k.shape[-1]

if np.isscalar(r):
_r = np.array(clean_order(cast(AnyOrder, r)))
else:
_r = clean_orders(cast(AnyOrderND, r))
r_min, r_max = np.min(_r), int(np.max(_r))
_r = clean_order(r)
r_min, r_max = np.min(_r), np.max(_r)

# TODO @jorenham: nan handling, see:
# https://github.com/jorenham/Lmo/issues/70

(s, t) = st = clean_trim(trim)

# ensure that any inf's (not nan's) are properly trimmed
if (s or t) and isinstance(s, int):
if isinstance(s, int):
if s:
x_k[..., :s] = np.nan_to_num(x_k[..., :s], nan=np.nan)
if t:
Expand Down
42 changes: 26 additions & 16 deletions lmo/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
import numpy.typing as npt

if sys.version_info >= (3, 13):
from typing import LiteralString, TypeVar
from typing import TypeVar
else:
from typing import LiteralString

from typing_extensions import TypeVar

if TYPE_CHECKING:
Expand All @@ -37,15 +35,15 @@

_SCT = TypeVar("_SCT", bound=np.generic)
_SCT_uifc = TypeVar("_SCT_uifc", bound="lnpt.Number")
_SCT_ui = TypeVar("_SCT_ui", bound="lnpt.Int", default=np.int_)
_SCT_f = TypeVar("_SCT_f", bound="lnpt.Float", default=np.float64)
_SCT_ui = TypeVar("_SCT_ui", bound=np.integer[Any], default=np.intp)
_SCT_f = TypeVar("_SCT_f", bound=np.floating[Any], default=np.float64)

_DT_f = TypeVar("_DT_f", bound=np.dtype["lnpt.Float"])
_AT_f = TypeVar("_AT_f", bound="npt.NDArray[lnpt.Float] | lnpt.Float")
_DT_f = TypeVar("_DT_f", bound=np.dtype[np.floating[Any]])
_AT_f = TypeVar("_AT_f", bound=np.floating[Any] | npt.NDArray[np.floating[Any]])

_SizeT = TypeVar("_SizeT", bound=int)

_ShapeT = TypeVar("_ShapeT", bound="onpt.AtLeast0D")
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
_ShapeT1 = TypeVar("_ShapeT1", bound="onpt.AtLeast1D")
_ShapeT2 = TypeVar("_ShapeT2", bound="onpt.AtLeast2D")

Expand Down Expand Up @@ -261,27 +259,39 @@ def ordered( # noqa: C901
return x_kk


@overload
def clean_order(r: lmt.AnyOrder, /, name: str = "r", rmin: int = 0) -> int: ...
@overload
def clean_order(
r: lmt.AnyOrderND,
/,
name: str = "r",
rmin: int = 0,
) -> npt.NDArray[np.intp]: ...
def clean_order(
r: lmt.AnyOrder,
r: lmt.AnyOrder | lmt.AnyOrderND,
/,
name: LiteralString = "r",
name: str = "r",
rmin: int = 0,
) -> int:
) -> int | npt.NDArray[np.intp]:
"""Validates and cleans an single (L-)moment order."""
if (_r := int(r)) < rmin:
msg = f"expected {name} >= {rmin}, got {_r}"
if not isinstance(r, int | np.integer):
return clean_orders(r, name=name, rmin=rmin)

if r < rmin:
msg = f"expected {name} >= {rmin}, got {r}"
raise TypeError(msg)

return _r
return int(r)


def clean_orders(
r: lmt.AnyOrderND,
/,
name: str = "r",
rmin: int = 0,
dtype: _DType[_SCT_ui] = np.int_,
) -> onpt.Array[Any, _SCT_ui]:
dtype: _DType[_SCT_ui] = np.intp,
) -> npt.NDArray[_SCT_ui]:
"""Validates and cleans an array-like of (L-)moment orders."""
_r = np.asarray_chkfinite(r, dtype=dtype)

Expand Down
132 changes: 69 additions & 63 deletions lmo/contrib/scipy_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@
moments_to_ratio,
round0,
)
from lmo.distributions._lm import get_lm_func, has_lm_func
from lmo.distributions._lm import get_lm_func, has_lm_func, prefers_ppf
from lmo.theoretical import (
l_moment_cov_from_cdf,
l_moment_from_cdf,
l_moment_from_ppf,
l_moment_influence_from_cdf,
l_ratio_influence_from_cdf,
l_stats_cov_from_cdf,
Expand Down Expand Up @@ -150,61 +151,6 @@ class l_rv_generic(PatchClass):
ppf: _Fn1
std: Callable[..., float]

def _get_xxf(
self,
*args: Any,
loc: float = 0,
scale: float = 1,
) -> _Tuple2[_Fn1]:
assert scale > 0

_cdf, _ppf = self._cdf, self._ppf

def cdf(x: _T_x, /) -> _T_x:
return _cdf(np.array([(x - loc) / scale], dtype=float), *args)[0]

def ppf(q: _T_x, /) -> _T_x:
return _ppf(np.array([q], dtype=float), *args)[0] * scale + loc

return cdf, ppf

def _l_moment(
self,
r: npt.NDArray[np.intp],
*args: Any,
trim: _Tuple2[int] | _Tuple2[float] = (0, 0),
quad_opts: lspt.QuadOptions | None = None,
) -> _ArrF8:
"""
Population L-moments of the standard distribution (i.e. assuming
`loc=0` and `scale=1`).
Todo:
- Sparse caching; key as `(self, args, r, trim)`, using a
priority queue. Prefer small `r` and `sum(trim)`, skip fractional
trim.
"""
name = self.name
if quad_opts is None and has_lm_func(name):
with contextlib.suppress(NotImplementedError):
return get_lm_func(name)(r, trim[0], trim[1], *args)

cdf, ppf = self._get_xxf(*args)

# TODO: use ppf when appropriate (e.g. genextreme, tukeylambda, kappa4)
with np.errstate(over="ignore", under="ignore"):
lmbda_r = l_moment_from_cdf(
cdf,
r,
trim=trim,
support=self._get_support(*args),
ppf=ppf,
quad_opts=quad_opts,
)

# re-wrap scalars in 0-d arrays (lmo.theoretical unpacks them)
return np.asarray(lmbda_r)

@np.errstate(divide="ignore")
def _logqdf(self, u: _ArrF8, *args: Any) -> _ArrF8:
"""Overridable log quantile distribution function (QDF)."""
Expand Down Expand Up @@ -337,7 +283,7 @@ def l_moment(
return np.full(_r.shape, np.nan)[()]

# L-moments of the standard distribution (loc=0, scale=scale0)
l0_r = self._l_moment(_r, *shapes, trim=_trim, quad_opts=quad_opts)
l0_r = _l_moment(self, _r, *shapes, trim=_trim, quad_opts=quad_opts)

# shift (by loc) and scale
shift_r = loc * (_r == 1)
Expand Down Expand Up @@ -688,7 +634,7 @@ def l_moments_cov(
self._parse_args(*args, **kwds),
)
support = self._get_support(*args)
cdf, _ = self._get_xxf(*args)
cdf, _ = _get_xxf(self, *args)

cov = l_moment_cov_from_cdf(
cdf,
Expand Down Expand Up @@ -799,7 +745,7 @@ def l_stats_cov(
"""
args, _, scale = self._parse_args(*args, **kwds)
support = self._get_support(*args)
cdf, ppf = self._get_xxf(*args)
cdf, ppf = _get_xxf(self, *args)

cov = l_stats_cov_from_cdf(
cdf,
Expand Down Expand Up @@ -889,7 +835,7 @@ def l_moment_influence(
lm = self.l_moment(r, *args, trim=trim, quad_opts=quad_opts, **kwds)

args, loc, scale = self._parse_args(*args, **kwds)
cdf = self._get_xxf(*args, loc=loc, scale=scale)[0]
cdf, _ = _get_xxf(self, *args, loc=loc, scale=scale)

return l_moment_influence_from_cdf(
cdf,
Expand Down Expand Up @@ -983,7 +929,7 @@ def l_ratio_influence(
)

args, loc, scale = self._parse_args(*args, **kwds)
cdf = self._get_xxf(*args, loc=loc, scale=scale)[0]
cdf = _get_xxf(self, *args, loc=loc, scale=scale)[0]

return l_ratio_influence_from_cdf(
cdf,
Expand Down Expand Up @@ -1219,7 +1165,7 @@ def l_fit(
r = np.arange(1, len(args0) + n_extra + 1)

_lmo_cache: dict[tuple[float, ...], _ArrF8] = {}
_lmo_fn = self._l_moment
_lmo_fn = _l_moment

# temporary cache to speed up L-moment calculations with the same
# shape args
Expand All @@ -1234,7 +1180,7 @@ def lmo_fn(
if shapes in _lmo_cache:
lmbda_r = _lmo_cache[shapes]
else:
lmbda_r = _lmo_fn(_r, *shapes, **kwds)
lmbda_r = _lmo_fn(self, _r, *shapes, **kwds)
lmbda_r.setflags(write=False) # prevent cache corruption
_lmo_cache[shapes] = lmbda_r

Expand Down Expand Up @@ -1480,6 +1426,66 @@ def l_ratio_influence( # noqa: D102
)


def _l_moment(
self: l_rv_generic | rv_continuous,
r: npt.NDArray[np.intp],
/,
*args: Any,
trim: _Tuple2[int] | _Tuple2[float] = (0, 0),
quad_opts: lspt.QuadOptions | None = None,
) -> _ArrF8:
"""
Population L-moments of the standard distribution (i.e. `loc, scale = 0, 1`).
Todo:
Sparse caching: Use a priority queue with key `(self, args, r, trim)`.
Prefer small `r` and small `sum(trim)`, skip fractional trim.
"""
name = self.name
if quad_opts is None and has_lm_func(name):
with contextlib.suppress(NotImplementedError):
return get_lm_func(name)(r, *trim, *args)

cdf, ppf = _get_xxf(self, *args)

if prefers_ppf(name):
lmbda_r = l_moment_from_ppf(ppf, r, trim=trim, quad_opts=quad_opts)
else:
a, b = self._get_support(*args)
with np.errstate(over="ignore", under="ignore"):
lmbda_r = l_moment_from_cdf(
cdf,
r,
trim=trim,
support=(float(a), float(b)),
ppf=ppf,
quad_opts=quad_opts,
)

# re-wrap scalars in 0-d arrays (lmo.theoretical unpacks them)
return np.asarray(lmbda_r)


def _get_xxf(
self: l_rv_generic | rv_continuous,
/,
*shape: float,
loc: float = 0,
scale: float = 1,
) -> _Tuple2[_Fn1]:
assert scale > 0

_cdf, _ppf = self._cdf, self._ppf

def cdf(x: _T_x, /) -> _T_x:
return _cdf(np.array([(x - loc) / scale], dtype=np.float64), *shape)[0]

def ppf(q: _T_x, /) -> _T_x:
return _ppf(np.array([q], dtype=np.float64), *shape)[0] * scale + loc

return cdf, ppf


def install() -> None:
"""
Add the public methods from
Expand Down
5 changes: 1 addition & 4 deletions lmo/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,4 @@
[Distributions - GLD](../distributions.md#gld).
"""
else:
genlambda: Final = cast(
lspt.RVContinuous,
genlambda_gen(name="genlambda"),
)
genlambda: Final = cast(lspt.RVContinuous, genlambda_gen(name="genlambda"))
Loading

0 comments on commit 39520bb

Please sign in to comment.