Skip to content

Commit

Permalink
Merge pull request numpy#25086 from mtsokol/array-api-aliases
Browse files Browse the repository at this point in the history
API: Add Array API aliases (math, bitwise, linalg, misc) [Array API]
  • Loading branch information
ngoldbaum authored Dec 8, 2023
2 parents a0ec186 + 0c3cad6 commit a5b67bb
Show file tree
Hide file tree
Showing 17 changed files with 213 additions and 125 deletions.
12 changes: 12 additions & 0 deletions doc/release/upcoming_changes/25086.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Array API compatible functions' aliases
---------------------------------------

13 aliases for existing functions were added to improve compatibility with the Array API standard:

* Trigonometry: ``acos``, ``acosh``, ``asin``, ``asinh``, ``atan``, ``atanh``, ``atan2``.

* Bitwise: ``bitwise_left_shift``, ``bitwise_invert``, ``bitwise_right_shift``.

* Misc: ``concat``, ``permute_dims``, ``pow``.

* linalg: ``tensordot``, ``matmul``.
56 changes: 0 additions & 56 deletions doc/source/reference/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,66 +71,10 @@ The following functions are named differently in the array API
* - Array API name
- NumPy namespace name
- Notes
* - ``acos``
- ``arccos``
-
* - ``acosh``
- ``arccosh``
-
* - ``asin``
- ``arcsin``
-
* - ``asinh``
- ``arcsinh``
-
* - ``atan``
- ``arctan``
-
* - ``atan2``
- ``arctan2``
-
* - ``atanh``
- ``arctanh``
-
* - ``bitwise_left_shift``
- ``left_shift``
-
* - ``bitwise_invert``
- ``invert``
-
* - ``bitwise_right_shift``
- ``right_shift``
-
* - ``bool``
- ``bool_``
- This is **breaking** because ``np.bool`` is currently a deprecated
alias for the built-in ``bool``.
* - ``concat``
- ``concatenate``
-
* - ``matrix_norm`` and ``vector_norm``
- ``norm``
- ``matrix_norm`` and ``vector_norm`` each do a limited subset of what
``np.norm`` does.
* - ``permute_dims``
- ``transpose``
- Unlike ``np.transpose``, the ``axis`` keyword-argument to
``permute_dims`` is required.
* - ``pow``
- ``power``
-


``linalg`` Namespace Differences
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

These functions are in the ``linalg`` sub-namespace in the array API, but are
only in the top-level namespace in NumPy:

- ``matmul`` (*)
- ``tensordot`` (*)

(*): These functions are also in the top-level namespace in the array API.

Keyword Argument Renames
~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions doc/source/reference/routines.array-manipulation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Transpose-like operations
swapaxes
ndarray.T
transpose
permute_dims

Changing number of dimensions
=============================
Expand Down Expand Up @@ -66,6 +67,7 @@ Joining arrays
:toctree: generated/

concatenate
concat
stack
block
vstack
Expand Down
3 changes: 3 additions & 0 deletions doc/source/reference/routines.bitwise.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ Elementwise bit operations
bitwise_or
bitwise_xor
invert
bitwise_invert
left_shift
bitwise_left_shift
right_shift
bitwise_right_shift

Bit packing
-----------
Expand Down
2 changes: 2 additions & 0 deletions doc/source/reference/routines.linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ Matrix and vector products
inner
outer
matmul
linalg.matmul (Array API compatible location)
tensordot
linalg.tensordot (Array API compatible location)
einsum
einsum_path
linalg.matrix_power
Expand Down
8 changes: 8 additions & 0 deletions doc/source/reference/routines.math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ Trigonometric functions
cos
tan
arcsin
asin
arccos
acos
arctan
atan
hypot
arctan2
atan2
degrees
radians
unwrap
Expand All @@ -31,8 +35,11 @@ Hyperbolic functions
cosh
tanh
arcsinh
asinh
arccosh
acosh
arctanh
atanh

Rounding
--------
Expand Down Expand Up @@ -120,6 +127,7 @@ Arithmetic operations
multiply
divide
power
pow
subtract
true_divide
floor_divide
Expand Down
61 changes: 31 additions & 30 deletions numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,25 @@
from . import _core
from ._core import (
False_, ScalarType, True_, _get_promotion_state, _no_nep50_warning,
_set_promotion_state, abs, absolute, add, all, allclose, alltrue,
amax, amin, any, arange, arccos, arccosh, arcsin, arcsinh, arctan,
arctan2, arctanh, argmax, argmin, argpartition, argsort, argwhere,
around, array, array2string, array_equal, array_equiv, array_repr,
array_str, asanyarray, asarray, ascontiguousarray, asfortranarray,
astype, atleast_1d, atleast_2d, atleast_3d, base_repr, binary_repr,
bitwise_and, bitwise_count, bitwise_not, bitwise_or, bitwise_xor,
block, bool, bool_, broadcast, busday_count, busday_offset,
busdaycalendar, byte, bytes_, can_cast, cbrt, cdouble, ceil,
character, choose, clip, clongdouble, complex128, complex64,
complexfloating, compress, concatenate, conj, conjugate, convolve,
copysign, copyto, correlate, cos, cosh, count_nonzero, cross, csingle,
cumprod, cumproduct, cumsum, datetime64, datetime_as_string,
datetime_data, deg2rad, degrees, diagonal, divide, divmod, dot,
double, dtype, e, einsum, einsum_path, empty, empty_like, equal,
errstate, euler_gamma, exp, exp2, expm1, fabs, finfo, flatiter,
flatnonzero, flexible, float16, float32, float64, float_power,
floating, floor, floor_divide, fmax, fmin, fmod,
_set_promotion_state, abs, absolute, acos, acosh, add, all, allclose,
alltrue, amax, amin, any, arange, arccos, arccosh, arcsin, arcsinh,
arctan, arctan2, arctanh, argmax, argmin, argpartition, argsort,
argwhere, around, array, array2string, array_equal, array_equiv,
array_repr, array_str, asanyarray, asarray, ascontiguousarray,
asfortranarray, asin, asinh, atan, atanh, atan2, astype, atleast_1d,
atleast_2d, atleast_3d, base_repr, binary_repr, bitwise_and,
bitwise_count, bitwise_invert, bitwise_left_shift, bitwise_not,
bitwise_or, bitwise_right_shift, bitwise_xor, block, bool, bool_,
broadcast, busday_count, busday_offset, busdaycalendar, byte, bytes_,
can_cast, cbrt, cdouble, ceil, character, choose, clip, clongdouble,
complex128, complex64, complexfloating, compress, concat, concatenate,
conj, conjugate, convolve, copysign, copyto, correlate, cos, cosh,
count_nonzero, cross, csingle, cumprod, cumproduct, cumsum,
datetime64, datetime_as_string, datetime_data, deg2rad, degrees,
diagonal, divide, divmod, dot, double, dtype, e, einsum, einsum_path,
empty, empty_like, equal, errstate, euler_gamma, exp, exp2, expm1,
fabs, finfo, flatiter, flatnonzero, flexible, float16, float32,
float64, float_power, floating, floor, floor_divide, fmax, fmin, fmod,
format_float_positional, format_float_scientific, frexp, from_dlpack,
frombuffer, fromfile, fromfunction, fromiter, frompyfunc, fromstring,
full, full_like, gcd, generic, geomspace, get_printoptions,
Expand All @@ -153,18 +154,18 @@
may_share_memory, mean, memmap, min, min_scalar_type, minimum, mod,
modf, moveaxis, multiply, nan, ndarray, ndim, nditer, negative,
nested_iters, newaxis, nextafter, nonzero, not_equal, number, object_,
ones, ones_like, outer, partition, pi, positive, power, printoptions,
prod, product, promote_types, ptp, put, putmask, rad2deg, radians,
ravel, recarray, reciprocal, record, remainder, repeat, require,
reshape, resize, result_type, right_shift, rint, roll, rollaxis,
round, sctypeDict, searchsorted, set_printoptions, setbufsize, seterr,
seterrcall, shape, shares_memory, short, sign, signbit, signedinteger,
sin, single, sinh, size, sometrue, sort, spacing, sqrt, square,
squeeze, stack, std, str_, subtract, sum, swapaxes, take, tan, tanh,
tensordot, timedelta64, trace, transpose, true_divide, trunc,
typecodes, ubyte, ufunc, uint, uint16, uint32, uint64, uint8, uintc,
uintp, ulong, ulonglong, unsignedinteger, ushort, var, vdot, void,
vstack, where, zeros, zeros_like
ones, ones_like, outer, partition, permute_dims, pi, positive, pow,
power, printoptions, prod, product, promote_types, ptp, put, putmask,
rad2deg, radians, ravel, recarray, reciprocal, record, remainder,
repeat, require, reshape, resize, result_type, right_shift, rint,
roll, rollaxis, round, sctypeDict, searchsorted, set_printoptions,
setbufsize, seterr, seterrcall, shape, shares_memory, short, sign,
signbit, signedinteger, sin, single, sinh, size, sometrue, sort,
spacing, sqrt, square, squeeze, stack, std, str_, subtract, sum,
swapaxes, take, tan, tanh, tensordot, timedelta64, trace, transpose,
true_divide, trunc, typecodes, ubyte, ufunc, uint, uint16, uint32,
uint64, uint8, uintc, uintp, ulong, ulonglong, unsignedinteger,
ushort, var, vdot, void, vstack, where, zeros, zeros_like
)

# NOTE: It's still under discussion whether these aliases
Expand Down
13 changes: 13 additions & 0 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3248,6 +3248,19 @@ true_divide: _UFunc_Nin2_Nout1[L['true_divide'], L[11], None]
trunc: _UFunc_Nin1_Nout1[L['trunc'], L[7], None]

abs = absolute
acos = arccos
acosh = arccosh
asin = arcsin
asinh = arcsinh
atan = arctan
atanh = arctanh
atan2 = arctan2
concat = concatenate
bitwise_left_shift = left_shift
bitwise_invert = invert
bitwise_right_shift = right_shift
permute_dims = transpose
pow = power

class _CopyMode(enum.Enum):
ALWAYS: L[True]
Expand Down
20 changes: 19 additions & 1 deletion numpy/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,25 @@
from . import _dtype
from . import _methods

__all__ = ['memmap', 'sctypeDict', 'record', 'recarray', 'abs']
acos = numeric.arccos
acosh = numeric.arccosh
asin = numeric.arcsin
asinh = numeric.arcsinh
atan = numeric.arctan
atanh = numeric.arctanh
atan2 = numeric.arctan2
concat = numeric.concatenate
bitwise_left_shift = numeric.left_shift
bitwise_invert = numeric.invert
bitwise_right_shift = numeric.right_shift
permute_dims = numeric.transpose
pow = numeric.power

__all__ = [
"abs", "acos", "acosh", "asin", "asinh", "atan", "atanh", "atan2",
"bitwise_invert", "bitwise_left_shift", "bitwise_right_shift", "concat",
"pow", "permute_dims", "memmap", "sctypeDict", "record", "recarray"
]
__all__ += numeric.__all__
__all__ += function_base.__all__
__all__ += getlimits.__all__
Expand Down
2 changes: 2 additions & 0 deletions numpy/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
cross
multi_dot
matrix_power
tensordot
matmul
Decompositions
--------------
Expand Down
5 changes: 5 additions & 0 deletions numpy/linalg/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ from numpy.linalg._linalg import (
cond as cond,
matrix_rank as matrix_rank,
multi_dot as multi_dot,
matmul as matmul,
trace as trace,
diagonal as diagonal,
cross as cross,
)

from numpy._core.numeric import (
tensordot as tensordot,
)

from numpy._pytesttester import PytestTester

__all__: list[str]
Expand Down
63 changes: 61 additions & 2 deletions numpy/linalg/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
'cholesky', 'eigvals', 'eigvalsh', 'pinv', 'slogdet', 'det',
'svd', 'svdvals', 'eig', 'eigh', 'lstsq', 'norm', 'qr', 'cond',
'matrix_rank', 'LinAlgError', 'multi_dot', 'trace', 'diagonal',
'cross', 'outer']
'cross', 'outer', 'tensordot', 'matmul']

import functools
import operator
Expand All @@ -28,7 +28,8 @@
amax, prod, abs, atleast_2d, intp, asanyarray, object_, matmul,
swapaxes, divide, count_nonzero, isnan, sign, argsort, sort,
reciprocal, overrides, diagonal as _core_diagonal, trace as _core_trace,
cross as _core_cross, outer as _core_outer
cross as _core_cross, outer as _core_outer, tensordot as _core_tensordot,
matmul as _core_matmul,
)
from numpy.lib._twodim_base_impl import triu, eye
from numpy.lib.array_utils import normalize_axis_index
Expand Down Expand Up @@ -3129,3 +3130,61 @@ def cross(x1, x2, /, *, axis=-1):
)

return _core_cross(x1, x2, axis=axis)


# matmul

def _matmul_dispatcher(x1, x2, /):
return (x1, x2)


@array_function_dispatch(_matmul_dispatcher)
def matmul(x1, x2, /):
"""
Computes the matrix product.
This function is Array API compatible, contrary to
:func:`numpy.matmul`.
Parameters
----------
x1 : array_like
The first input array.
x2 : array_like
The second input array.
Returns
-------
out : ndarray
The matrix product of the inputs.
This is a scalar only when both ``x1``, ``x2`` are 1-d vectors.
Raises
------
ValueError
If the last dimension of ``x1`` is not the same size as
the second-to-last dimension of ``x2``.
If a scalar value is passed in.
See Also
--------
numpy.matmul
"""
return _core_matmul(x1, x2)


# tensordot

def _tensordot_dispatcher(
x1, x2, /, *, offset=None, dtype=None):
return (x1, x2)


@array_function_dispatch(_tensordot_dispatcher)
def tensordot(x1, x2, /, *, axes=2):
return _core_tensordot(x1, x2, axes=axes)


tensordot.__doc__ = _core_tensordot.__doc__
Loading

0 comments on commit a5b67bb

Please sign in to comment.