summaryrefslogtreecommitdiffstats
path: root/venv/lib/python3.9/site-packages/numpy/array_api
diff options
context:
space:
mode:
authornoptuno <repollo.marrero@gmail.com>2023-04-28 02:40:47 +0200
committernoptuno <repollo.marrero@gmail.com>2023-04-28 02:40:47 +0200
commit6f6a73987201c9c303047c61389b82ad98b15597 (patch)
treebf67eb590d49979d6740bc1e94b4018df48bce98 /venv/lib/python3.9/site-packages/numpy/array_api
parentResolved merge conflicts and merged pr_218 into STREAMLIT_CHAT_IMPLEMENTATION (diff)
parentMerging PR_218 openai_rev package with new streamlit chat app (diff)
downloadgpt4free-6f6a73987201c9c303047c61389b82ad98b15597.tar
gpt4free-6f6a73987201c9c303047c61389b82ad98b15597.tar.gz
gpt4free-6f6a73987201c9c303047c61389b82ad98b15597.tar.bz2
gpt4free-6f6a73987201c9c303047c61389b82ad98b15597.tar.lz
gpt4free-6f6a73987201c9c303047c61389b82ad98b15597.tar.xz
gpt4free-6f6a73987201c9c303047c61389b82ad98b15597.tar.zst
gpt4free-6f6a73987201c9c303047c61389b82ad98b15597.zip
Diffstat (limited to 'venv/lib/python3.9/site-packages/numpy/array_api')
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/__init__.py377
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_array_object.py1118
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_constants.py6
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_creation_functions.py351
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_data_type_functions.py146
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_dtypes.py143
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_elementwise_functions.py729
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_manipulation_functions.py98
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_searching_functions.py47
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_set_functions.py106
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_sorting_functions.py49
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_statistical_functions.py115
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_typing.py74
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/_utility_functions.py37
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/linalg.py446
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/setup.py12
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/tests/__init__.py7
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/tests/test_array_object.py375
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/tests/test_creation_functions.py142
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/tests/test_data_type_functions.py19
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/tests/test_elementwise_functions.py111
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/tests/test_set_functions.py19
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/tests/test_sorting_functions.py23
-rw-r--r--venv/lib/python3.9/site-packages/numpy/array_api/tests/test_validation.py27
24 files changed, 4577 insertions, 0 deletions
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/__init__.py b/venv/lib/python3.9/site-packages/numpy/array_api/__init__.py
new file mode 100644
index 00000000..5e58ee0a
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/__init__.py
@@ -0,0 +1,377 @@
+"""
+A NumPy sub-namespace that conforms to the Python array API standard.
+
+This submodule accompanies NEP 47, which proposes its inclusion in NumPy. It
+is still considered experimental, and will issue a warning when imported.
+
+This is a proof-of-concept namespace that wraps the corresponding NumPy
+functions to give a conforming implementation of the Python array API standard
+(https://data-apis.github.io/array-api/latest/). The standard is currently in
+an RFC phase and comments on it are both welcome and encouraged. Comments
+should be made either at https://github.com/data-apis/array-api or at
+https://github.com/data-apis/consortium-feedback/discussions.
+
+NumPy already follows the proposed spec for the most part, so this module
+serves mostly as a thin wrapper around it. However, NumPy also implements a
+lot of behavior that is not included in the spec, so this serves as a
+restricted subset of the API. Only those functions that are part of the spec
+are included in this namespace, and all functions are given with the exact
+signature given in the spec, including the use of position-only arguments, and
+omitting any extra keyword arguments implemented by NumPy but not part of the
+spec. The behavior of some functions is also modified from the NumPy behavior
+to conform to the standard. Note that the underlying array object itself is
+wrapped in a wrapper Array() class, but is otherwise unchanged. This submodule
+is implemented in pure Python with no C extensions.
+
+The array API spec is designed as a "minimal API subset" and explicitly allows
+libraries to include behaviors not specified by it. But users of this module
+that intend to write portable code should be aware that only those behaviors
+that are listed in the spec are guaranteed to be implemented across libraries.
+Consequently, the NumPy implementation was chosen to be both conforming and
+minimal, so that users can use this implementation of the array API namespace
+and be sure that behaviors that it defines will be available in conforming
+namespaces from other libraries.
+
+A few notes about the current state of this submodule:
+
+- There is a test suite that tests modules against the array API standard at
+ https://github.com/data-apis/array-api-tests. The test suite is still a work
+ in progress, but the existing tests pass on this module, with a few
+ exceptions:
+
+ - DLPack support (see https://github.com/data-apis/array-api/pull/106) is
+ not included here, as it requires a full implementation in NumPy proper
+ first.
+
+ The test suite is not yet complete, and even the tests that exist are not
+ guaranteed to give a comprehensive coverage of the spec. Therefore, when
+ reviewing and using this submodule, you should refer to the standard
+ documents themselves. There are some tests in numpy.array_api.tests, but
+ they primarily focus on things that are not tested by the official array API
+ test suite.
+
+- There is a custom array object, numpy.array_api.Array, which is returned by
+ all functions in this module. All functions in the array API namespace
+ implicitly assume that they will only receive this object as input. The only
+ way to create instances of this object is to use one of the array creation
+ functions. It does not have a public constructor on the object itself. The
+ object is a small wrapper class around numpy.ndarray. The main purpose of it
+ is to restrict the namespace of the array object to only those dtypes and
+ only those methods that are required by the spec, as well as to limit/change
+ certain behavior that differs in the spec. In particular:
+
+ - The array API namespace does not have scalar objects, only 0-D arrays.
+ Operations on Array that would create a scalar in NumPy create a 0-D
+ array.
+
+ - Indexing: Only a subset of indices supported by NumPy are required by the
+ spec. The Array object restricts indexing to only allow those types of
+ indices that are required by the spec. See the docstring of the
+ numpy.array_api.Array._validate_indices helper function for more
+ information.
+
+ - Type promotion: Some type promotion rules are different in the spec. In
+ particular, the spec does not have any value-based casting. The spec also
+ does not require cross-kind casting, like integer -> floating-point. Only
+ those promotions that are explicitly required by the array API
+ specification are allowed in this module. See NEP 47 for more info.
+
+ - Functions do not automatically call asarray() on their input, and will not
+ work if the input type is not Array. The exception is array creation
+ functions, and Python operators on the Array object, which accept Python
+ scalars of the same type as the array dtype.
+
+- All functions include type annotations, corresponding to those given in the
+ spec (see _typing.py for definitions of some custom types). These do not
+ currently fully pass mypy due to some limitations in mypy.
+
+- Dtype objects are just the NumPy dtype objects, e.g., float64 =
+ np.dtype('float64'). The spec does not require any behavior on these dtype
+ objects other than that they be accessible by name and be comparable by
+ equality, but it was considered too much extra complexity to create custom
+ objects to represent dtypes.
+
+- All places where the implementations in this submodule are known to deviate
+ from their corresponding functions in NumPy are marked with "# Note:"
+ comments.
+
+Still TODO in this module are:
+
+- DLPack support for numpy.ndarray is still in progress. See
+ https://github.com/numpy/numpy/pull/19083.
+
+- The copy=False keyword argument to asarray() is not yet implemented. This
+ requires support in numpy.asarray() first.
+
+- Some functions are not yet fully tested in the array API test suite, and may
+ require updates that are not yet known until the tests are written.
+
+- The spec is still in an RFC phase and may still have minor updates, which
+ will need to be reflected here.
+
+- Complex number support in array API spec is planned but not yet finalized,
+ as are the fft extension and certain linear algebra functions such as eig
+ that require complex dtypes.
+
+"""
+
+import warnings
+
+warnings.warn(
+ "The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2
+)
+
+__array_api_version__ = "2021.12"
+
+__all__ = ["__array_api_version__"]
+
+from ._constants import e, inf, nan, pi
+
+__all__ += ["e", "inf", "nan", "pi"]
+
+from ._creation_functions import (
+ asarray,
+ arange,
+ empty,
+ empty_like,
+ eye,
+ from_dlpack,
+ full,
+ full_like,
+ linspace,
+ meshgrid,
+ ones,
+ ones_like,
+ tril,
+ triu,
+ zeros,
+ zeros_like,
+)
+
+__all__ += [
+ "asarray",
+ "arange",
+ "empty",
+ "empty_like",
+ "eye",
+ "from_dlpack",
+ "full",
+ "full_like",
+ "linspace",
+ "meshgrid",
+ "ones",
+ "ones_like",
+ "tril",
+ "triu",
+ "zeros",
+ "zeros_like",
+]
+
+from ._data_type_functions import (
+ astype,
+ broadcast_arrays,
+ broadcast_to,
+ can_cast,
+ finfo,
+ iinfo,
+ result_type,
+)
+
+__all__ += [
+ "astype",
+ "broadcast_arrays",
+ "broadcast_to",
+ "can_cast",
+ "finfo",
+ "iinfo",
+ "result_type",
+]
+
+from ._dtypes import (
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+ bool,
+)
+
+__all__ += [
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ "float32",
+ "float64",
+ "bool",
+]
+
+from ._elementwise_functions import (
+ abs,
+ acos,
+ acosh,
+ add,
+ asin,
+ asinh,
+ atan,
+ atan2,
+ atanh,
+ bitwise_and,
+ bitwise_left_shift,
+ bitwise_invert,
+ bitwise_or,
+ bitwise_right_shift,
+ bitwise_xor,
+ ceil,
+ cos,
+ cosh,
+ divide,
+ equal,
+ exp,
+ expm1,
+ floor,
+ floor_divide,
+ greater,
+ greater_equal,
+ isfinite,
+ isinf,
+ isnan,
+ less,
+ less_equal,
+ log,
+ log1p,
+ log2,
+ log10,
+ logaddexp,
+ logical_and,
+ logical_not,
+ logical_or,
+ logical_xor,
+ multiply,
+ negative,
+ not_equal,
+ positive,
+ pow,
+ remainder,
+ round,
+ sign,
+ sin,
+ sinh,
+ square,
+ sqrt,
+ subtract,
+ tan,
+ tanh,
+ trunc,
+)
+
+__all__ += [
+ "abs",
+ "acos",
+ "acosh",
+ "add",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "bitwise_and",
+ "bitwise_left_shift",
+ "bitwise_invert",
+ "bitwise_or",
+ "bitwise_right_shift",
+ "bitwise_xor",
+ "ceil",
+ "cos",
+ "cosh",
+ "divide",
+ "equal",
+ "exp",
+ "expm1",
+ "floor",
+ "floor_divide",
+ "greater",
+ "greater_equal",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "less",
+ "less_equal",
+ "log",
+ "log1p",
+ "log2",
+ "log10",
+ "logaddexp",
+ "logical_and",
+ "logical_not",
+ "logical_or",
+ "logical_xor",
+ "multiply",
+ "negative",
+ "not_equal",
+ "positive",
+ "pow",
+ "remainder",
+ "round",
+ "sign",
+ "sin",
+ "sinh",
+ "square",
+ "sqrt",
+ "subtract",
+ "tan",
+ "tanh",
+ "trunc",
+]
+
+# linalg is an extension in the array API spec, which is a sub-namespace. Only
+# a subset of functions in it are imported into the top-level namespace.
+from . import linalg
+
+__all__ += ["linalg"]
+
+from .linalg import matmul, tensordot, matrix_transpose, vecdot
+
+__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
+
+from ._manipulation_functions import (
+ concat,
+ expand_dims,
+ flip,
+ permute_dims,
+ reshape,
+ roll,
+ squeeze,
+ stack,
+)
+
+__all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"]
+
+from ._searching_functions import argmax, argmin, nonzero, where
+
+__all__ += ["argmax", "argmin", "nonzero", "where"]
+
+from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
+
+__all__ += ["unique_all", "unique_counts", "unique_inverse", "unique_values"]
+
+from ._sorting_functions import argsort, sort
+
+__all__ += ["argsort", "sort"]
+
+from ._statistical_functions import max, mean, min, prod, std, sum, var
+
+__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]
+
+from ._utility_functions import all, any
+
+__all__ += ["all", "any"]
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_array_object.py b/venv/lib/python3.9/site-packages/numpy/array_api/_array_object.py
new file mode 100644
index 00000000..c4746fad
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_array_object.py
@@ -0,0 +1,1118 @@
+"""
+Wrapper class around the ndarray object for the array API standard.
+
+The array API standard defines some behaviors differently than ndarray, in
+particular, type promotion rules are different (the standard has no
+value-based casting). The standard also specifies a more limited subset of
+array methods and functionalities than are implemented on ndarray. Since the
+goal of the array_api namespace is to be a minimal implementation of the array
+API standard, we need to define a separate wrapper class for the array_api
+namespace.
+
+The standard compliant class is only a wrapper class. It is *not* a subclass
+of ndarray.
+"""
+
+from __future__ import annotations
+
+import operator
+from enum import IntEnum
+from ._creation_functions import asarray
+from ._dtypes import (
+ _all_dtypes,
+ _boolean_dtypes,
+ _integer_dtypes,
+ _integer_or_boolean_dtypes,
+ _floating_dtypes,
+ _numeric_dtypes,
+ _result_type,
+ _dtype_categories,
+)
+
+from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex
+import types
+
+if TYPE_CHECKING:
+ from ._typing import Any, PyCapsule, Device, Dtype
+ import numpy.typing as npt
+
+import numpy as np
+
+from numpy import array_api
+
+
+class Array:
+ """
+ n-d array object for the array API namespace.
+
+ See the docstring of :py:obj:`np.ndarray <numpy.ndarray>` for more
+ information.
+
+ This is a wrapper around numpy.ndarray that restricts the usage to only
+ those things that are required by the array API namespace. Note,
+ attributes on this object that start with a single underscore are not part
+ of the API specification and should only be used internally. This object
+ should not be constructed directly. Rather, use one of the creation
+ functions, such as asarray().
+
+ """
+ _array: np.ndarray
+
+ # Use a custom constructor instead of __init__, as manually initializing
+ # this class is not supported API.
+ @classmethod
+ def _new(cls, x, /):
+ """
+ This is a private method for initializing the array API Array
+ object.
+
+ Functions outside of the array_api submodule should not use this
+ method. Use one of the creation functions instead, such as
+ ``asarray``.
+
+ """
+ obj = super().__new__(cls)
+ # Note: The spec does not have array scalars, only 0-D arrays.
+ if isinstance(x, np.generic):
+ # Convert the array scalar to a 0-D array
+ x = np.asarray(x)
+ if x.dtype not in _all_dtypes:
+ raise TypeError(
+ f"The array_api namespace does not support the dtype '{x.dtype}'"
+ )
+ obj._array = x
+ return obj
+
+ # Prevent Array() from working
+ def __new__(cls, *args, **kwargs):
+ raise TypeError(
+ "The array_api Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead."
+ )
+
+ # These functions are not required by the spec, but are implemented for
+ # the sake of usability.
+
+ def __str__(self: Array, /) -> str:
+ """
+ Performs the operation __str__.
+ """
+ return self._array.__str__().replace("array", "Array")
+
+ def __repr__(self: Array, /) -> str:
+ """
+ Performs the operation __repr__.
+ """
+ suffix = f", dtype={self.dtype.name})"
+ if 0 in self.shape:
+ prefix = "empty("
+ mid = str(self.shape)
+ else:
+ prefix = "Array("
+ mid = np.array2string(self._array, separator=', ', prefix=prefix, suffix=suffix)
+ return prefix + mid + suffix
+
+ # This function is not required by the spec, but we implement it here for
+ # convenience so that np.asarray(np.array_api.Array) will work.
+ def __array__(self, dtype: None | np.dtype[Any] = None) -> npt.NDArray[Any]:
+ """
+ Warning: this method is NOT part of the array API spec. Implementers
+ of other libraries need not include it, and users should not assume it
+ will be present in other implementations.
+
+ """
+ return np.asarray(self._array, dtype=dtype)
+
+ # These are various helper functions to make the array behavior match the
+ # spec in places where it either deviates from or is more strict than
+ # NumPy behavior
+
+ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_category: str, op: str) -> Array:
+ """
+ Helper function for operators to only allow specific input dtypes
+
+ Use like
+
+ other = self._check_allowed_dtypes(other, 'numeric', '__add__')
+ if other is NotImplemented:
+ return other
+ """
+
+ if self.dtype not in _dtype_categories[dtype_category]:
+ raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
+ if isinstance(other, (int, float, bool)):
+ other = self._promote_scalar(other)
+ elif isinstance(other, Array):
+ if other.dtype not in _dtype_categories[dtype_category]:
+ raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
+ else:
+ return NotImplemented
+
+ # This will raise TypeError for type combinations that are not allowed
+ # to promote in the spec (even if the NumPy array operator would
+ # promote them).
+ res_dtype = _result_type(self.dtype, other.dtype)
+ if op.startswith("__i"):
+ # Note: NumPy will allow in-place operators in some cases where
+ # the type promoted operator does not match the left-hand side
+ # operand. For example,
+
+ # >>> a = np.array(1, dtype=np.int8)
+ # >>> a += np.array(1, dtype=np.int16)
+
+ # The spec explicitly disallows this.
+ if res_dtype != self.dtype:
+ raise TypeError(
+ f"Cannot perform {op} with dtypes {self.dtype} and {other.dtype}"
+ )
+
+ return other
+
+ # Helper function to match the type promotion rules in the spec
+ def _promote_scalar(self, scalar):
+ """
+ Returns a promoted version of a Python scalar appropriate for use with
+ operations on self.
+
+ This may raise an OverflowError in cases where the scalar is an
+ integer that is too large to fit in a NumPy integer dtype, or
+ TypeError when the scalar type is incompatible with the dtype of self.
+ """
+ # Note: Only Python scalar types that match the array dtype are
+ # allowed.
+ if isinstance(scalar, bool):
+ if self.dtype not in _boolean_dtypes:
+ raise TypeError(
+ "Python bool scalars can only be promoted with bool arrays"
+ )
+ elif isinstance(scalar, int):
+ if self.dtype in _boolean_dtypes:
+ raise TypeError(
+ "Python int scalars cannot be promoted with bool arrays"
+ )
+ elif isinstance(scalar, float):
+ if self.dtype not in _floating_dtypes:
+ raise TypeError(
+ "Python float scalars can only be promoted with floating-point arrays."
+ )
+ else:
+ raise TypeError("'scalar' must be a Python scalar")
+
+ # Note: scalars are unconditionally cast to the same dtype as the
+ # array.
+
+ # Note: the spec only specifies integer-dtype/int promotion
+ # behavior for integers within the bounds of the integer dtype.
+ # Outside of those bounds we use the default NumPy behavior (either
+ # cast or raise OverflowError).
+ return Array._new(np.array(scalar, self.dtype))
+
+ @staticmethod
+ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
+ """
+ Normalize inputs to two arg functions to fix type promotion rules
+
+ NumPy deviates from the spec type promotion rules in cases where one
+ argument is 0-dimensional and the other is not. For example:
+
+ >>> import numpy as np
+ >>> a = np.array([1.0], dtype=np.float32)
+ >>> b = np.array(1.0, dtype=np.float64)
+ >>> np.add(a, b) # The spec says this should be float64
+ array([2.], dtype=float32)
+
+ To fix this, we add a dimension to the 0-dimension array before passing it
+ through. This works because a dimension would be added anyway from
+ broadcasting, so the resulting shape is the same, but this prevents NumPy
+ from not promoting the dtype.
+ """
+ # Another option would be to use signature=(x1.dtype, x2.dtype, None),
+ # but that only works for ufuncs, so we would have to call the ufuncs
+ # directly in the operator methods. One should also note that this
+ # sort of trick wouldn't work for functions like searchsorted, which
+ # don't do normal broadcasting, but there aren't any functions like
+ # that in the array API namespace.
+ if x1.ndim == 0 and x2.ndim != 0:
+ # The _array[None] workaround was chosen because it is relatively
+ # performant. broadcast_to(x1._array, x2.shape) is much slower. We
+ # could also manually type promote x2, but that is more complicated
+ # and about the same performance as this.
+ x1 = Array._new(x1._array[None])
+ elif x2.ndim == 0 and x1.ndim != 0:
+ x2 = Array._new(x2._array[None])
+ return (x1, x2)
+
+ # Note: A large fraction of allowed indices are disallowed here (see the
+ # docstring below)
+ def _validate_index(self, key):
+ """
+ Validate an index according to the array API.
+
+ The array API specification only requires a subset of indices that are
+ supported by NumPy. This function will reject any index that is
+ allowed by NumPy but not required by the array API specification. We
+ always raise ``IndexError`` on such indices (the spec does not require
+ any specific behavior on them, but this makes the NumPy array API
+ namespace a minimal implementation of the spec). See
+ https://data-apis.org/array-api/latest/API_specification/indexing.html
+ for the full list of required indexing behavior
+
+ This function raises IndexError if the index ``key`` is invalid. It
+ only raises ``IndexError`` on indices that are not already rejected by
+ NumPy, as NumPy will already raise the appropriate error on such
+ indices. ``shape`` may be None, in which case, only cases that are
+ independent of the array shape are checked.
+
+ The following cases are allowed by NumPy, but not specified by the array
+ API specification:
+
+ - Indices to not include an implicit ellipsis at the end. That is,
+ every axis of an array must be explicitly indexed or an ellipsis
+ included. This behaviour is sometimes referred to as flat indexing.
+
+ - The start and stop of a slice may not be out of bounds. In
+ particular, for a slice ``i:j:k`` on an axis of size ``n``, only the
+ following are allowed:
+
+ - ``i`` or ``j`` omitted (``None``).
+ - ``-n <= i <= max(0, n - 1)``.
+ - For ``k > 0`` or ``k`` omitted (``None``), ``-n <= j <= n``.
+ - For ``k < 0``, ``-n - 1 <= j <= max(0, n - 1)``.
+
+ - Boolean array indices are not allowed as part of a larger tuple
+ index.
+
+ - Integer array indices are not allowed (with the exception of 0-D
+ arrays, which are treated the same as scalars).
+
+ Additionally, it should be noted that indices that would return a
+ scalar in NumPy will return a 0-D array. Array scalars are not allowed
+ in the specification, only 0-D arrays. This is done in the
+ ``Array._new`` constructor, not this function.
+
+ """
+ _key = key if isinstance(key, tuple) else (key,)
+ for i in _key:
+ if isinstance(i, bool) or not (
+ isinstance(i, SupportsIndex) # i.e. ints
+ or isinstance(i, slice)
+ or i == Ellipsis
+ or i is None
+ or isinstance(i, Array)
+ or isinstance(i, np.ndarray)
+ ):
+ raise IndexError(
+ f"Single-axes index {i} has {type(i)=}, but only "
+ "integers, slices (:), ellipsis (...), newaxis (None), "
+ "zero-dimensional integer arrays and boolean arrays "
+ "are specified in the Array API."
+ )
+
+ nonexpanding_key = []
+ single_axes = []
+ n_ellipsis = 0
+ key_has_mask = False
+ for i in _key:
+ if i is not None:
+ nonexpanding_key.append(i)
+ if isinstance(i, Array) or isinstance(i, np.ndarray):
+ if i.dtype in _boolean_dtypes:
+ key_has_mask = True
+ single_axes.append(i)
+ else:
+ # i must not be an array here, to avoid elementwise equals
+ if i == Ellipsis:
+ n_ellipsis += 1
+ else:
+ single_axes.append(i)
+
+ n_single_axes = len(single_axes)
+ if n_ellipsis > 1:
+ return # handled by ndarray
+ elif n_ellipsis == 0:
+ # Note boolean masks must be the sole index, which we check for
+ # later on.
+ if not key_has_mask and n_single_axes < self.ndim:
+ raise IndexError(
+ f"{self.ndim=}, but the multi-axes index only specifies "
+ f"{n_single_axes} dimensions. If this was intentional, "
+ "add a trailing ellipsis (...) which expands into as many "
+ "slices (:) as necessary - this is what np.ndarray arrays "
+ "implicitly do, but such flat indexing behaviour is not "
+ "specified in the Array API."
+ )
+
+ if n_ellipsis == 0:
+ indexed_shape = self.shape
+ else:
+ ellipsis_start = None
+ for pos, i in enumerate(nonexpanding_key):
+ if not (isinstance(i, Array) or isinstance(i, np.ndarray)):
+ if i == Ellipsis:
+ ellipsis_start = pos
+ break
+ assert ellipsis_start is not None # sanity check
+ ellipsis_end = self.ndim - (n_single_axes - ellipsis_start)
+ indexed_shape = (
+ self.shape[:ellipsis_start] + self.shape[ellipsis_end:]
+ )
+ for i, side in zip(single_axes, indexed_shape):
+ if isinstance(i, slice):
+ if side == 0:
+ f_range = "0 (or None)"
+ else:
+ f_range = f"between -{side} and {side - 1} (or None)"
+ if i.start is not None:
+ try:
+ start = operator.index(i.start)
+ except TypeError:
+ pass # handled by ndarray
+ else:
+ if not (-side <= start <= side):
+ raise IndexError(
+ f"Slice {i} contains {start=}, but should be "
+ f"{f_range} for an axis of size {side} "
+ "(out-of-bounds starts are not specified in "
+ "the Array API)"
+ )
+ if i.stop is not None:
+ try:
+ stop = operator.index(i.stop)
+ except TypeError:
+ pass # handled by ndarray
+ else:
+ if not (-side <= stop <= side):
+ raise IndexError(
+ f"Slice {i} contains {stop=}, but should be "
+ f"{f_range} for an axis of size {side} "
+ "(out-of-bounds stops are not specified in "
+ "the Array API)"
+ )
+ elif isinstance(i, Array):
+ if i.dtype in _boolean_dtypes and len(_key) != 1:
+ assert isinstance(key, tuple) # sanity check
+ raise IndexError(
+ f"Single-axes index {i} is a boolean array and "
+ f"{len(key)=}, but masking is only specified in the "
+ "Array API when the array is the sole index."
+ )
+ elif i.dtype in _integer_dtypes and i.ndim != 0:
+ raise IndexError(
+ f"Single-axes index {i} is a non-zero-dimensional "
+ "integer array, but advanced integer indexing is not "
+ "specified in the Array API."
+ )
+ elif isinstance(i, tuple):
+ raise IndexError(
+ f"Single-axes index {i} is a tuple, but nested tuple "
+ "indices are not specified in the Array API."
+ )
+
+ # Everything below this line is required by the spec.
+
+ def __abs__(self: Array, /) -> Array:
+ """
+ Performs the operation __abs__.
+ """
+ if self.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in __abs__")
+ res = self._array.__abs__()
+ return self.__class__._new(res)
+
+ def __add__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __add__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__add__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__add__(other._array)
+ return self.__class__._new(res)
+
+ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array:
+ """
+ Performs the operation __and__.
+ """
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__and__(other._array)
+ return self.__class__._new(res)
+
+ def __array_namespace__(
+ self: Array, /, *, api_version: Optional[str] = None
+ ) -> types.ModuleType:
+ if api_version is not None and not api_version.startswith("2021."):
+ raise ValueError(f"Unrecognized array API version: {api_version!r}")
+ return array_api
+
+ def __bool__(self: Array, /) -> bool:
+ """
+ Performs the operation __bool__.
+ """
+ # Note: This is an error here.
+ if self._array.ndim != 0:
+ raise TypeError("bool is only allowed on arrays with 0 dimensions")
+ if self.dtype not in _boolean_dtypes:
+ raise ValueError("bool is only allowed on boolean arrays")
+ res = self._array.__bool__()
+ return res
+
+ def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
+ """
+ Performs the operation __dlpack__.
+ """
+ return self._array.__dlpack__(stream=stream)
+
+ def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
+ """
+ Performs the operation __dlpack_device__.
+ """
+ # Note: device support is required for this
+ return self._array.__dlpack_device__()
+
+ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
+ """
+ Performs the operation __eq__.
+ """
+ # Even though "all" dtypes are allowed, we still require them to be
+ # promotable with each other.
+ other = self._check_allowed_dtypes(other, "all", "__eq__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__eq__(other._array)
+ return self.__class__._new(res)
+
+ def __float__(self: Array, /) -> float:
+ """
+ Performs the operation __float__.
+ """
+ # Note: This is an error here.
+ if self._array.ndim != 0:
+ raise TypeError("float is only allowed on arrays with 0 dimensions")
+ if self.dtype not in _floating_dtypes:
+ raise ValueError("float is only allowed on floating-point arrays")
+ res = self._array.__float__()
+ return res
+
+ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __floordiv__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__floordiv__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__floordiv__(other._array)
+ return self.__class__._new(res)
+
+ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __ge__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__ge__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__ge__(other._array)
+ return self.__class__._new(res)
+
+ def __getitem__(
+ self: Array,
+ key: Union[
+ int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array
+ ],
+ /,
+ ) -> Array:
+ """
+ Performs the operation __getitem__.
+ """
+ # Note: Only indices required by the spec are allowed. See the
+ # docstring of _validate_index
+ self._validate_index(key)
+ if isinstance(key, Array):
+ # Indexing self._array with array_api arrays can be erroneous
+ key = key._array
+ res = self._array.__getitem__(key)
+ return self._new(res)
+
+ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __gt__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__gt__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__gt__(other._array)
+ return self.__class__._new(res)
+
+ def __int__(self: Array, /) -> int:
+ """
+ Performs the operation __int__.
+ """
+ # Note: This is an error here.
+ if self._array.ndim != 0:
+ raise TypeError("int is only allowed on arrays with 0 dimensions")
+ if self.dtype not in _integer_dtypes:
+ raise ValueError("int is only allowed on integer arrays")
+ res = self._array.__int__()
+ return res
+
+ def __index__(self: Array, /) -> int:
+ """
+ Performs the operation __index__.
+ """
+ res = self._array.__index__()
+ return res
+
+ def __invert__(self: Array, /) -> Array:
+ """
+ Performs the operation __invert__.
+ """
+ if self.dtype not in _integer_or_boolean_dtypes:
+ raise TypeError("Only integer or boolean dtypes are allowed in __invert__")
+ res = self._array.__invert__()
+ return self.__class__._new(res)
+
+ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __le__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__le__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__le__(other._array)
+ return self.__class__._new(res)
+
+ def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
+ """
+ Performs the operation __lshift__.
+ """
+ other = self._check_allowed_dtypes(other, "integer", "__lshift__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__lshift__(other._array)
+ return self.__class__._new(res)
+
+ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __lt__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__lt__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__lt__(other._array)
+ return self.__class__._new(res)
+
+ def __matmul__(self: Array, other: Array, /) -> Array:
+ """
+ Performs the operation __matmul__.
+ """
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._check_allowed_dtypes(other, "numeric", "__matmul__")
+ if other is NotImplemented:
+ return other
+ res = self._array.__matmul__(other._array)
+ return self.__class__._new(res)
+
+ def __mod__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __mod__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__mod__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__mod__(other._array)
+ return self.__class__._new(res)
+
+ def __mul__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __mul__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__mul__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__mul__(other._array)
+ return self.__class__._new(res)
+
+ def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
+ """
+ Performs the operation __ne__.
+ """
+ other = self._check_allowed_dtypes(other, "all", "__ne__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__ne__(other._array)
+ return self.__class__._new(res)
+
+ def __neg__(self: Array, /) -> Array:
+ """
+ Performs the operation __neg__.
+ """
+ if self.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in __neg__")
+ res = self._array.__neg__()
+ return self.__class__._new(res)
+
+ def __or__(self: Array, other: Union[int, bool, Array], /) -> Array:
+ """
+ Performs the operation __or__.
+ """
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__or__(other._array)
+ return self.__class__._new(res)
+
+ def __pos__(self: Array, /) -> Array:
+ """
+ Performs the operation __pos__.
+ """
+ if self.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in __pos__")
+ res = self._array.__pos__()
+ return self.__class__._new(res)
+
+ def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __pow__.
+ """
+ from ._elementwise_functions import pow
+
+ other = self._check_allowed_dtypes(other, "numeric", "__pow__")
+ if other is NotImplemented:
+ return other
+ # Note: NumPy's __pow__ does not follow type promotion rules for 0-d
+ # arrays, so we use pow() here instead.
+ return pow(self, other)
+
+ def __rshift__(self: Array, other: Union[int, Array], /) -> Array:
+ """
+ Performs the operation __rshift__.
+ """
+ other = self._check_allowed_dtypes(other, "integer", "__rshift__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rshift__(other._array)
+ return self.__class__._new(res)
+
+ def __setitem__(
+ self,
+ key: Union[
+ int, slice, ellipsis, Tuple[Union[int, slice, ellipsis], ...], Array
+ ],
+ value: Union[int, float, bool, Array],
+ /,
+ ) -> None:
+ """
+ Performs the operation __setitem__.
+ """
+ # Note: Only indices required by the spec are allowed. See the
+ # docstring of _validate_index
+ self._validate_index(key)
+ if isinstance(key, Array):
+ # Indexing self._array with array_api arrays can be erroneous
+ key = key._array
+ self._array.__setitem__(key, asarray(value)._array)
+
+ def __sub__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __sub__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__sub__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__sub__(other._array)
+ return self.__class__._new(res)
+
+ # PEP 484 requires int to be a subtype of float, but __truediv__ should
+ # not accept int.
+ def __truediv__(self: Array, other: Union[float, Array], /) -> Array:
+ """
+ Performs the operation __truediv__.
+ """
+ other = self._check_allowed_dtypes(other, "floating-point", "__truediv__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__truediv__(other._array)
+ return self.__class__._new(res)
+
+ def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array:
+ """
+ Performs the operation __xor__.
+ """
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__xor__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__xor__(other._array)
+ return self.__class__._new(res)
+
+ def __iadd__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __iadd__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__iadd__")
+ if other is NotImplemented:
+ return other
+ self._array.__iadd__(other._array)
+ return self
+
+ def __radd__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __radd__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__radd__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__radd__(other._array)
+ return self.__class__._new(res)
+
+ def __iand__(self: Array, other: Union[int, bool, Array], /) -> Array:
+ """
+ Performs the operation __iand__.
+ """
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__iand__")
+ if other is NotImplemented:
+ return other
+ self._array.__iand__(other._array)
+ return self
+
+ def __rand__(self: Array, other: Union[int, bool, Array], /) -> Array:
+ """
+ Performs the operation __rand__.
+ """
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__rand__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rand__(other._array)
+ return self.__class__._new(res)
+
+ def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __ifloordiv__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__ifloordiv__")
+ if other is NotImplemented:
+ return other
+ self._array.__ifloordiv__(other._array)
+ return self
+
+ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __rfloordiv__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rfloordiv__(other._array)
+ return self.__class__._new(res)
+
+ def __ilshift__(self: Array, other: Union[int, Array], /) -> Array:
+ """
+ Performs the operation __ilshift__.
+ """
+ other = self._check_allowed_dtypes(other, "integer", "__ilshift__")
+ if other is NotImplemented:
+ return other
+ self._array.__ilshift__(other._array)
+ return self
+
+ def __rlshift__(self: Array, other: Union[int, Array], /) -> Array:
+ """
+ Performs the operation __rlshift__.
+ """
+ other = self._check_allowed_dtypes(other, "integer", "__rlshift__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rlshift__(other._array)
+ return self.__class__._new(res)
+
+ def __imatmul__(self: Array, other: Array, /) -> Array:
+ """
+ Performs the operation __imatmul__.
+ """
+ # Note: NumPy does not implement __imatmul__.
+
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._check_allowed_dtypes(other, "numeric", "__imatmul__")
+ if other is NotImplemented:
+ return other
+
+ # __imatmul__ can only be allowed when it would not change the shape
+ # of self.
+ other_shape = other.shape
+ if self.shape == () or other_shape == ():
+ raise ValueError("@= requires at least one dimension")
+ if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]:
+ raise ValueError("@= cannot change the shape of the input array")
+ self._array[:] = self._array.__matmul__(other._array)
+ return self
+
+ def __rmatmul__(self: Array, other: Array, /) -> Array:
+ """
+ Performs the operation __rmatmul__.
+ """
+ # matmul is not defined for scalars, but without this, we may get
+ # the wrong error message from asarray.
+ other = self._check_allowed_dtypes(other, "numeric", "__rmatmul__")
+ if other is NotImplemented:
+ return other
+ res = self._array.__rmatmul__(other._array)
+ return self.__class__._new(res)
+
+ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __imod__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__imod__")
+ if other is NotImplemented:
+ return other
+ self._array.__imod__(other._array)
+ return self
+
+ def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __rmod__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__rmod__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rmod__(other._array)
+ return self.__class__._new(res)
+
+ def __imul__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __imul__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__imul__")
+ if other is NotImplemented:
+ return other
+ self._array.__imul__(other._array)
+ return self
+
+ def __rmul__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __rmul__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__rmul__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rmul__(other._array)
+ return self.__class__._new(res)
+
+ def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array:
+ """
+ Performs the operation __ior__.
+ """
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__")
+ if other is NotImplemented:
+ return other
+ self._array.__ior__(other._array)
+ return self
+
+ def __ror__(self: Array, other: Union[int, bool, Array], /) -> Array:
+ """
+ Performs the operation __ror__.
+ """
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__ror__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__ror__(other._array)
+ return self.__class__._new(res)
+
+ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __ipow__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
+ if other is NotImplemented:
+ return other
+ self._array.__ipow__(other._array)
+ return self
+
+ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __rpow__.
+ """
+ from ._elementwise_functions import pow
+
+ other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
+ if other is NotImplemented:
+ return other
+ # Note: NumPy's __pow__ does not follow the spec type promotion rules
+ # for 0-d arrays, so we use pow() here instead.
+ return pow(other, self)
+
+ def __irshift__(self: Array, other: Union[int, Array], /) -> Array:
+ """
+ Performs the operation __irshift__.
+ """
+ other = self._check_allowed_dtypes(other, "integer", "__irshift__")
+ if other is NotImplemented:
+ return other
+ self._array.__irshift__(other._array)
+ return self
+
+ def __rrshift__(self: Array, other: Union[int, Array], /) -> Array:
+ """
+ Performs the operation __rrshift__.
+ """
+ other = self._check_allowed_dtypes(other, "integer", "__rrshift__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rrshift__(other._array)
+ return self.__class__._new(res)
+
+ def __isub__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __isub__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__isub__")
+ if other is NotImplemented:
+ return other
+ self._array.__isub__(other._array)
+ return self
+
+ def __rsub__(self: Array, other: Union[int, float, Array], /) -> Array:
+ """
+ Performs the operation __rsub__.
+ """
+ other = self._check_allowed_dtypes(other, "numeric", "__rsub__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rsub__(other._array)
+ return self.__class__._new(res)
+
+ def __itruediv__(self: Array, other: Union[float, Array], /) -> Array:
+ """
+ Performs the operation __itruediv__.
+ """
+ other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__")
+ if other is NotImplemented:
+ return other
+ self._array.__itruediv__(other._array)
+ return self
+
+ def __rtruediv__(self: Array, other: Union[float, Array], /) -> Array:
+ """
+ Performs the operation __rtruediv__.
+ """
+ other = self._check_allowed_dtypes(other, "floating-point", "__rtruediv__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rtruediv__(other._array)
+ return self.__class__._new(res)
+
+ def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array:
+ """
+ Performs the operation __ixor__.
+ """
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__")
+ if other is NotImplemented:
+ return other
+ self._array.__ixor__(other._array)
+ return self
+
+ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
+ """
+ Performs the operation __rxor__.
+ """
+ other = self._check_allowed_dtypes(other, "integer or boolean", "__rxor__")
+ if other is NotImplemented:
+ return other
+ self, other = self._normalize_two_args(self, other)
+ res = self._array.__rxor__(other._array)
+ return self.__class__._new(res)
+
+ def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
+ if stream is not None:
+ raise ValueError("The stream argument to to_device() is not supported")
+ if device == 'cpu':
+ return self
+ raise ValueError(f"Unsupported device {device!r}")
+
+ @property
+ def dtype(self) -> Dtype:
+ """
+ Array API compatible wrapper for :py:meth:`np.ndarray.dtype <numpy.ndarray.dtype>`.
+
+ See its docstring for more information.
+ """
+ return self._array.dtype
+
+ @property
+ def device(self) -> Device:
+ return "cpu"
+
+ # Note: mT is new in array API spec (see matrix_transpose)
+ @property
+ def mT(self) -> Array:
+ from .linalg import matrix_transpose
+ return matrix_transpose(self)
+
+ @property
+ def ndim(self) -> int:
+ """
+ Array API compatible wrapper for :py:meth:`np.ndarray.ndim <numpy.ndarray.ndim>`.
+
+ See its docstring for more information.
+ """
+ return self._array.ndim
+
+ @property
+ def shape(self) -> Tuple[int, ...]:
+ """
+ Array API compatible wrapper for :py:meth:`np.ndarray.shape <numpy.ndarray.shape>`.
+
+ See its docstring for more information.
+ """
+ return self._array.shape
+
+ @property
+ def size(self) -> int:
+ """
+ Array API compatible wrapper for :py:meth:`np.ndarray.size <numpy.ndarray.size>`.
+
+ See its docstring for more information.
+ """
+ return self._array.size
+
+ @property
+ def T(self) -> Array:
+ """
+ Array API compatible wrapper for :py:meth:`np.ndarray.T <numpy.ndarray.T>`.
+
+ See its docstring for more information.
+ """
+ # Note: T only works on 2-dimensional arrays. See the corresponding
+ # note in the specification:
+ # https://data-apis.org/array-api/latest/API_specification/array_object.html#t
+ if self.ndim != 2:
+ raise ValueError("x.T requires x to have 2 dimensions. Use x.mT to transpose stacks of matrices and permute_dims() to permute dimensions.")
+ return self.__class__._new(self._array.T)
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_constants.py b/venv/lib/python3.9/site-packages/numpy/array_api/_constants.py
new file mode 100644
index 00000000..9541941e
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_constants.py
@@ -0,0 +1,6 @@
+import numpy as np
+
+e = np.e
+inf = np.inf
+nan = np.nan
+pi = np.pi
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_creation_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/_creation_functions.py
new file mode 100644
index 00000000..3b014d37
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_creation_functions.py
@@ -0,0 +1,351 @@
+from __future__ import annotations
+
+
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+ from ._typing import (
+ Array,
+ Device,
+ Dtype,
+ NestedSequence,
+ SupportsBufferProtocol,
+ )
+ from collections.abc import Sequence
+from ._dtypes import _all_dtypes
+
+import numpy as np
+
+
+def _check_valid_dtype(dtype):
+ # Note: Only spelling dtypes as the dtype objects is supported.
+
+ # We use this instead of "dtype in _all_dtypes" because the dtype objects
+ # define equality with the sorts of things we want to disallow.
+ for d in (None,) + _all_dtypes:
+ if dtype is d:
+ return
+ raise ValueError("dtype must be one of the supported dtypes")
+
+
+def asarray(
+ obj: Union[
+ Array,
+ bool,
+ int,
+ float,
+ NestedSequence[bool | int | float],
+ SupportsBufferProtocol,
+ ],
+ /,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ copy: Optional[Union[bool, np._CopyMode]] = None,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
+
+ See its docstring for more information.
+ """
+ # _array_object imports in this file are inside the functions to avoid
+ # circular imports
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ if copy in (False, np._CopyMode.IF_NEEDED):
+ # Note: copy=False is not yet implemented in np.asarray
+ raise NotImplementedError("copy=False is not yet implemented")
+ if isinstance(obj, Array):
+ if dtype is not None and obj.dtype != dtype:
+ copy = True
+ if copy in (True, np._CopyMode.ALWAYS):
+ return Array._new(np.array(obj._array, copy=True, dtype=dtype))
+ return obj
+ if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
+ # Give a better error message in this case. NumPy would convert this
+ # to an object array. TODO: This won't handle large integers in lists.
+ raise OverflowError("Integer out of bounds for array dtypes")
+ res = np.asarray(obj, dtype=dtype)
+ return Array._new(res)
+
+
+def arange(
+ start: Union[int, float],
+ /,
+ stop: Optional[Union[int, float]] = None,
+ step: Union[int, float] = 1,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
+
+
+def empty(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ return Array._new(np.empty(shape, dtype=dtype))
+
+
+def empty_like(
+ x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ return Array._new(np.empty_like(x._array, dtype=dtype))
+
+
+def eye(
+ n_rows: int,
+ n_cols: Optional[int] = None,
+ /,
+ *,
+ k: int = 0,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
+
+
+def from_dlpack(x: object, /) -> Array:
+ from ._array_object import Array
+
+ return Array._new(np.from_dlpack(x))
+
+
+def full(
+ shape: Union[int, Tuple[int, ...]],
+ fill_value: Union[int, float],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.full <numpy.full>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ if isinstance(fill_value, Array) and fill_value.ndim == 0:
+ fill_value = fill_value._array
+ res = np.full(shape, fill_value, dtype=dtype)
+ if res.dtype not in _all_dtypes:
+ # This will happen if the fill value is not something that NumPy
+ # coerces to one of the acceptable dtypes.
+ raise TypeError("Invalid input to full")
+ return Array._new(res)
+
+
+def full_like(
+ x: Array,
+ /,
+ fill_value: Union[int, float],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ res = np.full_like(x._array, fill_value, dtype=dtype)
+ if res.dtype not in _all_dtypes:
+ # This will happen if the fill value is not something that NumPy
+ # coerces to one of the acceptable dtypes.
+ raise TypeError("Invalid input to full_like")
+ return Array._new(res)
+
+
+def linspace(
+ start: Union[int, float],
+ stop: Union[int, float],
+ /,
+ num: int,
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+ endpoint: bool = True,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
+
+
+def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
+ """
+ Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ # Note: unlike np.meshgrid, only inputs with all the same dtype are
+ # allowed
+
+ if len({a.dtype for a in arrays}) > 1:
+ raise ValueError("meshgrid inputs must all have the same dtype")
+
+ return [
+ Array._new(array)
+ for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
+ ]
+
+
+def ones(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ return Array._new(np.ones(shape, dtype=dtype))
+
+
+def ones_like(
+ x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ return Array._new(np.ones_like(x._array, dtype=dtype))
+
+
+def tril(x: Array, /, *, k: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.tril <numpy.tril>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ if x.ndim < 2:
+ # Note: Unlike np.tril, x must be at least 2-D
+ raise ValueError("x must be at least 2-dimensional for tril")
+ return Array._new(np.tril(x._array, k=k))
+
+
+def triu(x: Array, /, *, k: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.triu <numpy.triu>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ if x.ndim < 2:
+ # Note: Unlike np.triu, x must be at least 2-D
+ raise ValueError("x must be at least 2-dimensional for triu")
+ return Array._new(np.triu(x._array, k=k))
+
+
+def zeros(
+ shape: Union[int, Tuple[int, ...]],
+ *,
+ dtype: Optional[Dtype] = None,
+ device: Optional[Device] = None,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ return Array._new(np.zeros(shape, dtype=dtype))
+
+
+def zeros_like(
+ x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ _check_valid_dtype(dtype)
+ if device not in ["cpu", None]:
+ raise ValueError(f"Unsupported device {device!r}")
+ return Array._new(np.zeros_like(x._array, dtype=dtype))
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_data_type_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/_data_type_functions.py
new file mode 100644
index 00000000..7026bd48
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_data_type_functions.py
@@ -0,0 +1,146 @@
+from __future__ import annotations
+
+from ._array_object import Array
+from ._dtypes import _all_dtypes, _result_type
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Tuple, Union
+
+if TYPE_CHECKING:
+ from ._typing import Dtype
+ from collections.abc import Sequence
+
+import numpy as np
+
+
+# Note: astype is a function, not an array method as in NumPy.
+def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
+ if not copy and dtype == x.dtype:
+ return x
+ return Array._new(x._array.astype(dtype=dtype, copy=copy))
+
+
+def broadcast_arrays(*arrays: Array) -> List[Array]:
+ """
+ Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ return [
+ Array._new(array) for array in np.broadcast_arrays(*[a._array for a in arrays])
+ ]
+
+
+def broadcast_to(x: Array, /, shape: Tuple[int, ...]) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.broadcast_to <numpy.broadcast_to>`.
+
+ See its docstring for more information.
+ """
+ from ._array_object import Array
+
+ return Array._new(np.broadcast_to(x._array, shape))
+
+
+def can_cast(from_: Union[Dtype, Array], to: Dtype, /) -> bool:
+ """
+ Array API compatible wrapper for :py:func:`np.can_cast <numpy.can_cast>`.
+
+ See its docstring for more information.
+ """
+ if isinstance(from_, Array):
+ from_ = from_.dtype
+ elif from_ not in _all_dtypes:
+ raise TypeError(f"{from_=}, but should be an array_api array or dtype")
+ if to not in _all_dtypes:
+ raise TypeError(f"{to=}, but should be a dtype")
+ # Note: We avoid np.can_cast() as it has discrepancies with the array API,
+ # since NumPy allows cross-kind casting (e.g., NumPy allows bool -> int8).
+ # See https://github.com/numpy/numpy/issues/20870
+ try:
+ # We promote `from_` and `to` together. We then check if the promoted
+ # dtype is `to`, which indicates if `from_` can (up)cast to `to`.
+ dtype = _result_type(from_, to)
+ return to == dtype
+ except TypeError:
+ # _result_type() raises if the dtypes don't promote together
+ return False
+
+
+# These are internal objects for the return types of finfo and iinfo, since
+# the NumPy versions contain extra data that isn't part of the spec.
+@dataclass
+class finfo_object:
+ bits: int
+ # Note: The types of the float data here are float, whereas in NumPy they
+ # are scalars of the corresponding float dtype.
+ eps: float
+ max: float
+ min: float
+ smallest_normal: float
+
+
+@dataclass
+class iinfo_object:
+ bits: int
+ max: int
+ min: int
+
+
+def finfo(type: Union[Dtype, Array], /) -> finfo_object:
+ """
+ Array API compatible wrapper for :py:func:`np.finfo <numpy.finfo>`.
+
+ See its docstring for more information.
+ """
+ fi = np.finfo(type)
+ # Note: The types of the float data here are float, whereas in NumPy they
+ # are scalars of the corresponding float dtype.
+ return finfo_object(
+ fi.bits,
+ float(fi.eps),
+ float(fi.max),
+ float(fi.min),
+ float(fi.smallest_normal),
+ )
+
+
+def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
+ """
+ Array API compatible wrapper for :py:func:`np.iinfo <numpy.iinfo>`.
+
+ See its docstring for more information.
+ """
+ ii = np.iinfo(type)
+ return iinfo_object(ii.bits, ii.max, ii.min)
+
+
+def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
+ """
+ Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
+
+ See its docstring for more information.
+ """
+ # Note: we use a custom implementation that gives only the type promotions
+ # required by the spec rather than using np.result_type. NumPy implements
+ # too many extra type promotions like int64 + uint64 -> float64, and does
+ # value-based casting on scalar arrays.
+ A = []
+ for a in arrays_and_dtypes:
+ if isinstance(a, Array):
+ a = a.dtype
+ elif isinstance(a, np.ndarray) or a not in _all_dtypes:
+ raise TypeError("result_type() inputs must be array_api arrays or dtypes")
+ A.append(a)
+
+ if len(A) == 0:
+ raise ValueError("at least one array or dtype is required")
+ elif len(A) == 1:
+ return A[0]
+ else:
+ t = A[0]
+ for t2 in A[1:]:
+ t = _result_type(t, t2)
+ return t
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_dtypes.py b/venv/lib/python3.9/site-packages/numpy/array_api/_dtypes.py
new file mode 100644
index 00000000..476d619f
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_dtypes.py
@@ -0,0 +1,143 @@
+import numpy as np
+
+# Note: we use dtype objects instead of dtype classes. The spec does not
+# require any behavior on dtypes other than equality.
+int8 = np.dtype("int8")
+int16 = np.dtype("int16")
+int32 = np.dtype("int32")
+int64 = np.dtype("int64")
+uint8 = np.dtype("uint8")
+uint16 = np.dtype("uint16")
+uint32 = np.dtype("uint32")
+uint64 = np.dtype("uint64")
+float32 = np.dtype("float32")
+float64 = np.dtype("float64")
+# Note: This name is changed
+bool = np.dtype("bool")
+
+_all_dtypes = (
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+ bool,
+)
+_boolean_dtypes = (bool,)
+_floating_dtypes = (float32, float64)
+_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64)
+_integer_or_boolean_dtypes = (
+ bool,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+)
+_numeric_dtypes = (
+ float32,
+ float64,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+)
+
+_dtype_categories = {
+ "all": _all_dtypes,
+ "numeric": _numeric_dtypes,
+ "integer": _integer_dtypes,
+ "integer or boolean": _integer_or_boolean_dtypes,
+ "boolean": _boolean_dtypes,
+ "floating-point": _floating_dtypes,
+}
+
+
+# Note: the spec defines a restricted type promotion table compared to NumPy.
+# In particular, cross-kind promotions like integer + float or boolean +
+# integer are not allowed, even for functions that accept both kinds.
+# Additionally, NumPy promotes signed integer + uint64 to float64, but this
+# promotion is not allowed here. To be clear, Python scalar int objects are
+# allowed to promote to floating-point dtypes, but only in array operators
+# (see Array._promote_scalar) method in _array_object.py.
+_promotion_table = {
+ (int8, int8): int8,
+ (int8, int16): int16,
+ (int8, int32): int32,
+ (int8, int64): int64,
+ (int16, int8): int16,
+ (int16, int16): int16,
+ (int16, int32): int32,
+ (int16, int64): int64,
+ (int32, int8): int32,
+ (int32, int16): int32,
+ (int32, int32): int32,
+ (int32, int64): int64,
+ (int64, int8): int64,
+ (int64, int16): int64,
+ (int64, int32): int64,
+ (int64, int64): int64,
+ (uint8, uint8): uint8,
+ (uint8, uint16): uint16,
+ (uint8, uint32): uint32,
+ (uint8, uint64): uint64,
+ (uint16, uint8): uint16,
+ (uint16, uint16): uint16,
+ (uint16, uint32): uint32,
+ (uint16, uint64): uint64,
+ (uint32, uint8): uint32,
+ (uint32, uint16): uint32,
+ (uint32, uint32): uint32,
+ (uint32, uint64): uint64,
+ (uint64, uint8): uint64,
+ (uint64, uint16): uint64,
+ (uint64, uint32): uint64,
+ (uint64, uint64): uint64,
+ (int8, uint8): int16,
+ (int8, uint16): int32,
+ (int8, uint32): int64,
+ (int16, uint8): int16,
+ (int16, uint16): int32,
+ (int16, uint32): int64,
+ (int32, uint8): int32,
+ (int32, uint16): int32,
+ (int32, uint32): int64,
+ (int64, uint8): int64,
+ (int64, uint16): int64,
+ (int64, uint32): int64,
+ (uint8, int8): int16,
+ (uint16, int8): int32,
+ (uint32, int8): int64,
+ (uint8, int16): int16,
+ (uint16, int16): int32,
+ (uint32, int16): int64,
+ (uint8, int32): int32,
+ (uint16, int32): int32,
+ (uint32, int32): int64,
+ (uint8, int64): int64,
+ (uint16, int64): int64,
+ (uint32, int64): int64,
+ (float32, float32): float32,
+ (float32, float64): float64,
+ (float64, float32): float64,
+ (float64, float64): float64,
+ (bool, bool): bool,
+}
+
+
+def _result_type(type1, type2):
+ if (type1, type2) in _promotion_table:
+ return _promotion_table[type1, type2]
+ raise TypeError(f"{type1} and {type2} cannot be type promoted together")
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_elementwise_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/_elementwise_functions.py
new file mode 100644
index 00000000..c758a094
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_elementwise_functions.py
@@ -0,0 +1,729 @@
+from __future__ import annotations
+
+from ._dtypes import (
+ _boolean_dtypes,
+ _floating_dtypes,
+ _integer_dtypes,
+ _integer_or_boolean_dtypes,
+ _numeric_dtypes,
+ _result_type,
+)
+from ._array_object import Array
+
+import numpy as np
+
+
+def abs(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.abs <numpy.abs>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in abs")
+ return Array._new(np.abs(x._array))
+
+
+# Note: the function name is different here
+def acos(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.arccos <numpy.arccos>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in acos")
+ return Array._new(np.arccos(x._array))
+
+
+# Note: the function name is different here
+def acosh(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.arccosh <numpy.arccosh>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in acosh")
+ return Array._new(np.arccosh(x._array))
+
+
+def add(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.add <numpy.add>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in add")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.add(x1._array, x2._array))
+
+
+# Note: the function name is different here
+def asin(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.arcsin <numpy.arcsin>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in asin")
+ return Array._new(np.arcsin(x._array))
+
+
+# Note: the function name is different here
+def asinh(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.arcsinh <numpy.arcsinh>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in asinh")
+ return Array._new(np.arcsinh(x._array))
+
+
+# Note: the function name is different here
+def atan(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.arctan <numpy.arctan>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in atan")
+ return Array._new(np.arctan(x._array))
+
+
+# Note: the function name is different here
+def atan2(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.arctan2 <numpy.arctan2>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in atan2")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.arctan2(x1._array, x2._array))
+
+
+# Note: the function name is different here
+def atanh(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.arctanh <numpy.arctanh>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in atanh")
+ return Array._new(np.arctanh(x._array))
+
+
+def bitwise_and(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.bitwise_and <numpy.bitwise_and>`.
+
+ See its docstring for more information.
+ """
+ if (
+ x1.dtype not in _integer_or_boolean_dtypes
+ or x2.dtype not in _integer_or_boolean_dtypes
+ ):
+ raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.bitwise_and(x1._array, x2._array))
+
+
+# Note: the function name is different here
+def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.left_shift <numpy.left_shift>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
+ raise TypeError("Only integer dtypes are allowed in bitwise_left_shift")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ # Note: bitwise_left_shift is only defined for x2 nonnegative.
+ if np.any(x2._array < 0):
+ raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
+ return Array._new(np.left_shift(x1._array, x2._array))
+
+
+# Note: the function name is different here
+def bitwise_invert(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.invert <numpy.invert>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _integer_or_boolean_dtypes:
+ raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert")
+ return Array._new(np.invert(x._array))
+
+
+def bitwise_or(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.bitwise_or <numpy.bitwise_or>`.
+
+ See its docstring for more information.
+ """
+ if (
+ x1.dtype not in _integer_or_boolean_dtypes
+ or x2.dtype not in _integer_or_boolean_dtypes
+ ):
+ raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.bitwise_or(x1._array, x2._array))
+
+
+# Note: the function name is different here
+def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.right_shift <numpy.right_shift>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _integer_dtypes or x2.dtype not in _integer_dtypes:
+ raise TypeError("Only integer dtypes are allowed in bitwise_right_shift")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ # Note: bitwise_right_shift is only defined for x2 nonnegative.
+ if np.any(x2._array < 0):
+ raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0")
+ return Array._new(np.right_shift(x1._array, x2._array))
+
+
+def bitwise_xor(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.bitwise_xor <numpy.bitwise_xor>`.
+
+ See its docstring for more information.
+ """
+ if (
+ x1.dtype not in _integer_or_boolean_dtypes
+ or x2.dtype not in _integer_or_boolean_dtypes
+ ):
+ raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.bitwise_xor(x1._array, x2._array))
+
+
+def ceil(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.ceil <numpy.ceil>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in ceil")
+ if x.dtype in _integer_dtypes:
+ # Note: The return dtype of ceil is the same as the input
+ return x
+ return Array._new(np.ceil(x._array))
+
+
+def cos(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.cos <numpy.cos>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in cos")
+ return Array._new(np.cos(x._array))
+
+
+def cosh(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.cosh <numpy.cosh>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in cosh")
+ return Array._new(np.cosh(x._array))
+
+
+def divide(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.divide <numpy.divide>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in divide")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.divide(x1._array, x2._array))
+
+
+def equal(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.equal <numpy.equal>`.
+
+ See its docstring for more information.
+ """
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.equal(x1._array, x2._array))
+
+
+def exp(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.exp <numpy.exp>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in exp")
+ return Array._new(np.exp(x._array))
+
+
+def expm1(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.expm1 <numpy.expm1>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in expm1")
+ return Array._new(np.expm1(x._array))
+
+
+def floor(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.floor <numpy.floor>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in floor")
+ if x.dtype in _integer_dtypes:
+ # Note: The return dtype of floor is the same as the input
+ return x
+ return Array._new(np.floor(x._array))
+
+
+def floor_divide(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.floor_divide <numpy.floor_divide>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in floor_divide")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.floor_divide(x1._array, x2._array))
+
+
+def greater(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.greater <numpy.greater>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in greater")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.greater(x1._array, x2._array))
+
+
+def greater_equal(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.greater_equal <numpy.greater_equal>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in greater_equal")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.greater_equal(x1._array, x2._array))
+
+
+def isfinite(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.isfinite <numpy.isfinite>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in isfinite")
+ return Array._new(np.isfinite(x._array))
+
+
+def isinf(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.isinf <numpy.isinf>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in isinf")
+ return Array._new(np.isinf(x._array))
+
+
+def isnan(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.isnan <numpy.isnan>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in isnan")
+ return Array._new(np.isnan(x._array))
+
+
+def less(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.less <numpy.less>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in less")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.less(x1._array, x2._array))
+
+
+def less_equal(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in less_equal")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.less_equal(x1._array, x2._array))
+
+
+def log(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.log <numpy.log>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in log")
+ return Array._new(np.log(x._array))
+
+
+def log1p(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.log1p <numpy.log1p>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in log1p")
+ return Array._new(np.log1p(x._array))
+
+
+def log2(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.log2 <numpy.log2>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in log2")
+ return Array._new(np.log2(x._array))
+
+
+def log10(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.log10 <numpy.log10>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in log10")
+ return Array._new(np.log10(x._array))
+
+
+def logaddexp(x1: Array, x2: Array) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.logaddexp <numpy.logaddexp>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in logaddexp")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.logaddexp(x1._array, x2._array))
+
+
+def logical_and(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.logical_and <numpy.logical_and>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
+ raise TypeError("Only boolean dtypes are allowed in logical_and")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.logical_and(x1._array, x2._array))
+
+
+def logical_not(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.logical_not <numpy.logical_not>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _boolean_dtypes:
+ raise TypeError("Only boolean dtypes are allowed in logical_not")
+ return Array._new(np.logical_not(x._array))
+
+
+def logical_or(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.logical_or <numpy.logical_or>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
+ raise TypeError("Only boolean dtypes are allowed in logical_or")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.logical_or(x1._array, x2._array))
+
+
+def logical_xor(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.logical_xor <numpy.logical_xor>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
+ raise TypeError("Only boolean dtypes are allowed in logical_xor")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.logical_xor(x1._array, x2._array))
+
+
+def multiply(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.multiply <numpy.multiply>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in multiply")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.multiply(x1._array, x2._array))
+
+
+def negative(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.negative <numpy.negative>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in negative")
+ return Array._new(np.negative(x._array))
+
+
+def not_equal(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.
+
+ See its docstring for more information.
+ """
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.not_equal(x1._array, x2._array))
+
+
+def positive(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.positive <numpy.positive>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in positive")
+ return Array._new(np.positive(x._array))
+
+
+# Note: the function name is different here
+def pow(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.power <numpy.power>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in pow")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.power(x1._array, x2._array))
+
+
+def remainder(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in remainder")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.remainder(x1._array, x2._array))
+
+
+def round(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.round <numpy.round>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in round")
+ return Array._new(np.round(x._array))
+
+
+def sign(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.sign <numpy.sign>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in sign")
+ return Array._new(np.sign(x._array))
+
+
+def sin(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.sin <numpy.sin>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in sin")
+ return Array._new(np.sin(x._array))
+
+
+def sinh(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.sinh <numpy.sinh>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in sinh")
+ return Array._new(np.sinh(x._array))
+
+
+def square(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.square <numpy.square>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in square")
+ return Array._new(np.square(x._array))
+
+
+def sqrt(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.sqrt <numpy.sqrt>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in sqrt")
+ return Array._new(np.sqrt(x._array))
+
+
+def subtract(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.subtract <numpy.subtract>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in subtract")
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.subtract(x1._array, x2._array))
+
+
+def tan(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.tan <numpy.tan>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in tan")
+ return Array._new(np.tan(x._array))
+
+
+def tanh(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.tanh <numpy.tanh>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in tanh")
+ return Array._new(np.tanh(x._array))
+
+
+def trunc(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.trunc <numpy.trunc>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in trunc")
+ if x.dtype in _integer_dtypes:
+ # Note: The return dtype of trunc is the same as the input
+ return x
+ return Array._new(np.trunc(x._array))
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_manipulation_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/_manipulation_functions.py
new file mode 100644
index 00000000..7991f46a
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_manipulation_functions.py
@@ -0,0 +1,98 @@
+from __future__ import annotations
+
+from ._array_object import Array
+from ._data_type_functions import result_type
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+
+# Note: the function name is different here
+def concat(
+ arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
+
+ See its docstring for more information.
+ """
+ # Note: Casting rules here are different from the np.concatenate default
+ # (no for scalars with axis=None, no cross-kind casting)
+ dtype = result_type(*arrays)
+ arrays = tuple(a._array for a in arrays)
+ return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype))
+
+
+def expand_dims(x: Array, /, *, axis: int) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.expand_dims <numpy.expand_dims>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.expand_dims(x._array, axis))
+
+
+def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.flip(x._array, axis=axis))
+
+
+# Note: The function name is different here (see also matrix_transpose).
+# Unlike transpose(), the axes argument is required.
+def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.transpose(x._array, axes))
+
+
+# Note: the optional argument is called 'shape', not 'newshape'
+def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.reshape(x._array, shape))
+
+
+def roll(
+ x: Array,
+ /,
+ shift: Union[int, Tuple[int, ...]],
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.roll(x._array, shift, axis=axis))
+
+
+def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.squeeze(x._array, axis=axis))
+
+
+def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.
+
+ See its docstring for more information.
+ """
+ # Call result type here just to raise on disallowed type combinations
+ result_type(*arrays)
+ arrays = tuple(a._array for a in arrays)
+ return Array._new(np.stack(arrays, axis=axis))
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_searching_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/_searching_functions.py
new file mode 100644
index 00000000..40f5a4d2
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_searching_functions.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+from ._array_object import Array
+from ._dtypes import _result_type
+
+from typing import Optional, Tuple
+
+import numpy as np
+
+
+def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.argmax <numpy.argmax>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims)))
+
+
+def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.argmin <numpy.argmin>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims)))
+
+
+def nonzero(x: Array, /) -> Tuple[Array, ...]:
+ """
+ Array API compatible wrapper for :py:func:`np.nonzero <numpy.nonzero>`.
+
+ See its docstring for more information.
+ """
+ return tuple(Array._new(i) for i in np.nonzero(x._array))
+
+
+def where(condition: Array, x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
+
+ See its docstring for more information.
+ """
+ # Call result type here just to raise on disallowed type combinations
+ _result_type(x1.dtype, x2.dtype)
+ x1, x2 = Array._normalize_two_args(x1, x2)
+ return Array._new(np.where(condition._array, x1._array, x2._array))
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_set_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/_set_functions.py
new file mode 100644
index 00000000..0b4132cf
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_set_functions.py
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+from ._array_object import Array
+
+from typing import NamedTuple
+
+import numpy as np
+
+# Note: np.unique() is split into four functions in the array API:
+# unique_all, unique_counts, unique_inverse, and unique_values (this is done
+# to remove polymorphic return types).
+
+# Note: The various unique() functions are supposed to return multiple NaNs.
+# This does not match the NumPy behavior, however, this is currently left as a
+# TODO in this implementation as this behavior may be reverted in np.unique().
+# See https://github.com/numpy/numpy/issues/20326.
+
+# Note: The functions here return a namedtuple (np.unique() returns a normal
+# tuple).
+
+class UniqueAllResult(NamedTuple):
+ values: Array
+ indices: Array
+ inverse_indices: Array
+ counts: Array
+
+
+class UniqueCountsResult(NamedTuple):
+ values: Array
+ counts: Array
+
+
+class UniqueInverseResult(NamedTuple):
+ values: Array
+ inverse_indices: Array
+
+
+def unique_all(x: Array, /) -> UniqueAllResult:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ values, indices, inverse_indices, counts = np.unique(
+ x._array,
+ return_counts=True,
+ return_index=True,
+ return_inverse=True,
+ equal_nan=False,
+ )
+ # np.unique() flattens inverse indices, but they need to share x's shape
+ # See https://github.com/numpy/numpy/issues/20638
+ inverse_indices = inverse_indices.reshape(x.shape)
+ return UniqueAllResult(
+ Array._new(values),
+ Array._new(indices),
+ Array._new(inverse_indices),
+ Array._new(counts),
+ )
+
+
+def unique_counts(x: Array, /) -> UniqueCountsResult:
+ res = np.unique(
+ x._array,
+ return_counts=True,
+ return_index=False,
+ return_inverse=False,
+ equal_nan=False,
+ )
+
+ return UniqueCountsResult(*[Array._new(i) for i in res])
+
+
+def unique_inverse(x: Array, /) -> UniqueInverseResult:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ values, inverse_indices = np.unique(
+ x._array,
+ return_counts=False,
+ return_index=False,
+ return_inverse=True,
+ equal_nan=False,
+ )
+ # np.unique() flattens inverse indices, but they need to share x's shape
+ # See https://github.com/numpy/numpy/issues/20638
+ inverse_indices = inverse_indices.reshape(x.shape)
+ return UniqueInverseResult(Array._new(values), Array._new(inverse_indices))
+
+
+def unique_values(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.unique <numpy.unique>`.
+
+ See its docstring for more information.
+ """
+ res = np.unique(
+ x._array,
+ return_counts=False,
+ return_index=False,
+ return_inverse=False,
+ equal_nan=False,
+ )
+ return Array._new(res)
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_sorting_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/_sorting_functions.py
new file mode 100644
index 00000000..afbb412f
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_sorting_functions.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from ._array_object import Array
+
+import numpy as np
+
+
+# Note: the descending keyword argument is new in this function
+def argsort(
+ x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.argsort <numpy.argsort>`.
+
+ See its docstring for more information.
+ """
+ # Note: this keyword argument is different, and the default is different.
+ kind = "stable" if stable else "quicksort"
+ if not descending:
+ res = np.argsort(x._array, axis=axis, kind=kind)
+ else:
+ # As NumPy has no native descending sort, we imitate it here. Note that
+ # simply flipping the results of np.argsort(x._array, ...) would not
+ # respect the relative order like it would in native descending sorts.
+ res = np.flip(
+ np.argsort(np.flip(x._array, axis=axis), axis=axis, kind=kind),
+ axis=axis,
+ )
+ # Rely on flip()/argsort() to validate axis
+ normalised_axis = axis if axis >= 0 else x.ndim + axis
+ max_i = x.shape[normalised_axis] - 1
+ res = max_i - res
+ return Array._new(res)
+
+# Note: the descending keyword argument is new in this function
+def sort(
+ x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.sort <numpy.sort>`.
+
+ See its docstring for more information.
+ """
+ # Note: this keyword argument is different, and the default is different.
+ kind = "stable" if stable else "quicksort"
+ res = np.sort(x._array, axis=axis, kind=kind)
+ if descending:
+ res = np.flip(res, axis=axis)
+ return Array._new(res)
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_statistical_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/_statistical_functions.py
new file mode 100644
index 00000000..5bc831ac
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_statistical_functions.py
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from ._dtypes import (
+ _floating_dtypes,
+ _numeric_dtypes,
+)
+from ._array_object import Array
+from ._creation_functions import asarray
+from ._dtypes import float32, float64
+
+from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+ from ._typing import Dtype
+
+import numpy as np
+
+
+def max(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in max")
+ return Array._new(np.max(x._array, axis=axis, keepdims=keepdims))
+
+
+def mean(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in mean")
+ return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims))
+
+
+def min(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in min")
+ return Array._new(np.min(x._array, axis=axis, keepdims=keepdims))
+
+
+def prod(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ dtype: Optional[Dtype] = None,
+ keepdims: bool = False,
+) -> Array:
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in prod")
+ # Note: sum() and prod() always upcast float32 to float64 for dtype=None
+ # We need to do so here before computing the product to avoid overflow
+ if dtype is None and x.dtype == float32:
+ dtype = float64
+ return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
+
+
+def std(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ correction: Union[int, float] = 0.0,
+ keepdims: bool = False,
+) -> Array:
+ # Note: the keyword argument correction is different here
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in std")
+ return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))
+
+
+def sum(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ dtype: Optional[Dtype] = None,
+ keepdims: bool = False,
+) -> Array:
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError("Only numeric dtypes are allowed in sum")
+ # Note: sum() and prod() always upcast integers to (u)int64 and float32 to
+ # float64 for dtype=None. `np.sum` does that too for integers, but not for
+ # float32, so we need to special-case it here
+ if dtype is None and x.dtype == float32:
+ dtype = float64
+ return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
+
+
+def var(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ correction: Union[int, float] = 0.0,
+ keepdims: bool = False,
+) -> Array:
+ # Note: the keyword argument correction is different here
+ if x.dtype not in _floating_dtypes:
+ raise TypeError("Only floating-point dtypes are allowed in var")
+ return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_typing.py b/venv/lib/python3.9/site-packages/numpy/array_api/_typing.py
new file mode 100644
index 00000000..dfa87b35
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_typing.py
@@ -0,0 +1,74 @@
+"""
+This file defines the types for type annotations.
+
+These names aren't part of the module namespace, but they are used in the
+annotations in the function signatures. The functions in the module are only
+valid for inputs that match the given type annotations.
+"""
+
+from __future__ import annotations
+
+__all__ = [
+ "Array",
+ "Device",
+ "Dtype",
+ "SupportsDLPack",
+ "SupportsBufferProtocol",
+ "PyCapsule",
+]
+
+import sys
+from typing import (
+ Any,
+ Literal,
+ Sequence,
+ Type,
+ Union,
+ TYPE_CHECKING,
+ TypeVar,
+ Protocol,
+)
+
+from ._array_object import Array
+from numpy import (
+ dtype,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+)
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+class NestedSequence(Protocol[_T_co]):
+ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
+ def __len__(self, /) -> int: ...
+
+Device = Literal["cpu"]
+if TYPE_CHECKING or sys.version_info >= (3, 9):
+ Dtype = dtype[Union[
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float32,
+ float64,
+ ]]
+else:
+ Dtype = dtype
+
+SupportsBufferProtocol = Any
+PyCapsule = Any
+
+class SupportsDLPack(Protocol):
+ def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/_utility_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/_utility_functions.py
new file mode 100644
index 00000000..5ecb4bd9
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/_utility_functions.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from ._array_object import Array
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+
+
+def all(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.all <numpy.all>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.asarray(np.all(x._array, axis=axis, keepdims=keepdims)))
+
+
+def any(
+ x: Array,
+ /,
+ *,
+ axis: Optional[Union[int, Tuple[int, ...]]] = None,
+ keepdims: bool = False,
+) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.any <numpy.any>`.
+
+ See its docstring for more information.
+ """
+ return Array._new(np.asarray(np.any(x._array, axis=axis, keepdims=keepdims)))
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/linalg.py b/venv/lib/python3.9/site-packages/numpy/array_api/linalg.py
new file mode 100644
index 00000000..d214046e
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/linalg.py
@@ -0,0 +1,446 @@
+from __future__ import annotations
+
+from ._dtypes import _floating_dtypes, _numeric_dtypes
+from ._manipulation_functions import reshape
+from ._array_object import Array
+
+from ..core.numeric import normalize_axis_tuple
+
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from ._typing import Literal, Optional, Sequence, Tuple, Union
+
+from typing import NamedTuple
+
+import numpy.linalg
+import numpy as np
+
+class EighResult(NamedTuple):
+ eigenvalues: Array
+ eigenvectors: Array
+
+class QRResult(NamedTuple):
+ Q: Array
+ R: Array
+
+class SlogdetResult(NamedTuple):
+ sign: Array
+ logabsdet: Array
+
+class SVDResult(NamedTuple):
+ U: Array
+ S: Array
+ Vh: Array
+
+# Note: the inclusion of the upper keyword is different from
+# np.linalg.cholesky, which does not have it.
+def cholesky(x: Array, /, *, upper: bool = False) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.cholesky <numpy.linalg.cholesky>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.cholesky.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in cholesky')
+ L = np.linalg.cholesky(x._array)
+ if upper:
+ return Array._new(L).mT
+ return Array._new(L)
+
+# Note: cross is the numpy top-level namespace, not np.linalg
+def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in cross')
+ # Note: this is different from np.cross(), which broadcasts
+ if x1.shape != x2.shape:
+ raise ValueError('x1 and x2 must have the same shape')
+ if x1.ndim == 0:
+ raise ValueError('cross() requires arrays of dimension at least 1')
+ # Note: this is different from np.cross(), which allows dimension 2
+ if x1.shape[axis] != 3:
+ raise ValueError('cross() dimension must equal 3')
+ return Array._new(np.cross(x1._array, x2._array, axis=axis))
+
+def det(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.det.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in det')
+ return Array._new(np.linalg.det(x._array))
+
+# Note: diagonal is the numpy top-level namespace, not np.linalg
+def diagonal(x: Array, /, *, offset: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
+
+ See its docstring for more information.
+ """
+ # Note: diagonal always operates on the last two axes, whereas np.diagonal
+ # operates on the first two axes by default
+ return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1))
+
+
+def eigh(x: Array, /) -> EighResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.eigh.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in eigh')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.eigh, which only returns a tuple.
+ return EighResult(*map(Array._new, np.linalg.eigh(x._array)))
+
+
+def eigvalsh(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.eigvalsh.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in eigvalsh')
+
+ return Array._new(np.linalg.eigvalsh(x._array))
+
+def inv(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.inv.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in inv')
+
+ return Array._new(np.linalg.inv(x._array))
+
+
+# Note: matmul is the numpy top-level namespace but not in np.linalg
+def matmul(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to numeric dtypes only is different from
+ # np.matmul.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in matmul')
+
+ return Array._new(np.matmul(x1._array, x2._array))
+
+
+# Note: the name here is different from norm(). The array API norm is split
+# into matrix_norm and vector_norm().
+
+# The type for ord should be Optional[Union[int, float, Literal[np.inf,
+# -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point
+# literals.
+def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.norm.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in matrix_norm')
+
+ return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord))
+
+
+def matrix_power(x: Array, n: int, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.matrix_power <numpy.matrix_power>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.matrix_power.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed for the first argument of matrix_power')
+
+ # np.matrix_power already checks if n is an integer
+ return Array._new(np.linalg.matrix_power(x._array, n))
+
+# Note: the keyword argument name rtol is different from np.linalg.matrix_rank
+def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`.
+
+ See its docstring for more information.
+ """
+ # Note: this is different from np.linalg.matrix_rank, which supports 1
+ # dimensional arrays.
+ if x.ndim < 2:
+ raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
+ S = np.linalg.svd(x._array, compute_uv=False)
+ if rtol is None:
+ tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps
+ else:
+ if isinstance(rtol, Array):
+ rtol = rtol._array
+ # Note: this is different from np.linalg.matrix_rank, which does not multiply
+ # the tolerance by the largest singular value.
+ tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis]
+ return Array._new(np.count_nonzero(S > tol, axis=-1))
+
+
+# Note: this function is new in the array API spec. Unlike transpose, it only
+# transposes the last two axes.
+def matrix_transpose(x: Array, /) -> Array:
+ if x.ndim < 2:
+ raise ValueError("x must be at least 2-dimensional for matrix_transpose")
+ return Array._new(np.swapaxes(x._array, -1, -2))
+
+# Note: outer is the numpy top-level namespace, not np.linalg
+def outer(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to numeric dtypes only is different from
+ # np.outer.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in outer')
+
+ # Note: the restriction to only 1-dim arrays is different from np.outer
+ if x1.ndim != 1 or x2.ndim != 1:
+ raise ValueError('The input arrays to outer must be 1-dimensional')
+
+ return Array._new(np.outer(x1._array, x2._array))
+
+# Note: the keyword argument name rtol is different from np.linalg.pinv
+def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.pinv <numpy.linalg.pinv>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.pinv.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in pinv')
+
+ # Note: this is different from np.linalg.pinv, which does not multiply the
+ # default tolerance by max(M, N).
+ if rtol is None:
+ rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps
+ return Array._new(np.linalg.pinv(x._array, rcond=rtol))
+
+def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.qr <numpy.linalg.qr>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.qr.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in qr')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.linalg.qr, which only returns a tuple.
+ return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode)))
+
+def slogdet(x: Array, /) -> SlogdetResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.slogdet <numpy.linalg.slogdet>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.slogdet.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in slogdet')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.linalg.slogdet, which only returns a tuple.
+ return SlogdetResult(*map(Array._new, np.linalg.slogdet(x._array)))
+
+# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
+# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
+# of matrices. The np.linalg.solve behavior of allowing stacks of both
+# matrices and vectors is ambiguous c.f.
+# https://github.com/numpy/numpy/issues/15349 and
+# https://github.com/data-apis/array-api/issues/285.
+
+# To workaround this, the below is the code from np.linalg.solve except
+# only calling solve1 in the exactly 1D case.
+def _solve(a, b):
+ from ..linalg.linalg import (_makearray, _assert_stacked_2d,
+ _assert_stacked_square, _commonType,
+ isComplexType, get_linalg_error_extobj,
+ _raise_linalgerror_singular)
+ from ..linalg import _umath_linalg
+
+ a, _ = _makearray(a)
+ _assert_stacked_2d(a)
+ _assert_stacked_square(a)
+ b, wrap = _makearray(b)
+ t, result_t = _commonType(a, b)
+
+ # This part is different from np.linalg.solve
+ if b.ndim == 1:
+ gufunc = _umath_linalg.solve1
+ else:
+ gufunc = _umath_linalg.solve
+
+ # This does nothing currently but is left in because it will be relevant
+ # when complex dtype support is added to the spec in 2022.
+ signature = 'DD->D' if isComplexType(t) else 'dd->d'
+ extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
+ r = gufunc(a, b, signature=signature, extobj=extobj)
+
+ return wrap(r.astype(result_t, copy=False))
+
+def solve(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.solve <numpy.linalg.solve>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.solve.
+ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in solve')
+
+ return Array._new(_solve(x1._array, x2._array))
+
+def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.svd <numpy.linalg.svd>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.svd.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in svd')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.svd, which only returns a tuple.
+ return SVDResult(*map(Array._new, np.linalg.svd(x._array, full_matrices=full_matrices)))
+
+# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
+# np.linalg.svd(compute_uv=False).
+def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in svdvals')
+ return Array._new(np.linalg.svd(x._array, compute_uv=False))
+
+# Note: tensordot is the numpy top-level namespace but not in np.linalg
+
+# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
+def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
+ # Note: the restriction to numeric dtypes only is different from
+ # np.tensordot.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in tensordot')
+
+ return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
+
+# Note: trace is the numpy top-level namespace, not np.linalg
+def trace(x: Array, /, *, offset: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
+
+ See its docstring for more information.
+ """
+ if x.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in trace')
+ # Note: trace always operates on the last two axes, whereas np.trace
+ # operates on the first two axes by default
+ return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))
+
+# Note: vecdot is not in NumPy
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in vecdot')
+ ndim = max(x1.ndim, x2.ndim)
+ x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
+ x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
+ if x1_shape[axis] != x2_shape[axis]:
+ raise ValueError("x1 and x2 must have the same size along the given axis")
+
+ x1_, x2_ = np.broadcast_arrays(x1._array, x2._array)
+ x1_ = np.moveaxis(x1_, axis, -1)
+ x2_ = np.moveaxis(x2_, axis, -1)
+
+ res = x1_[..., None, :] @ x2_[..., None]
+ return Array._new(res[..., 0, 0])
+
+
+# Note: the name here is different from norm(). The array API norm is split
+# into matrix_norm and vector_norm().
+
+# The type for ord should be Optional[Union[int, float, Literal[np.inf,
+# -np.inf]]] but Literal does not support floating-point literals.
+def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.norm.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in norm')
+
+ # np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
+ # when axis=None and the input is 2-D, so to force a vector norm, we make
+ # it so the input is 1-D (for axis=None), or reshape so that norm is done
+ # on a single dimension.
+ a = x._array
+ if axis is None:
+ # Note: np.linalg.norm() doesn't handle 0-D arrays
+ a = a.ravel()
+ _axis = 0
+ elif isinstance(axis, tuple):
+ # Note: The axis argument supports any number of axes, whereas
+ # np.linalg.norm() only supports a single axis for vector norm.
+ normalized_axis = normalize_axis_tuple(axis, x.ndim)
+ rest = tuple(i for i in range(a.ndim) if i not in normalized_axis)
+ newshape = axis + rest
+ a = np.transpose(a, newshape).reshape(
+ (np.prod([a.shape[i] for i in axis], dtype=int), *[a.shape[i] for i in rest]))
+ _axis = 0
+ else:
+ _axis = axis
+
+ res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord))
+
+ if keepdims:
+ # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks
+ # above to avoid matrix norm logic.
+ shape = list(x.shape)
+ _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
+ for i in _axis:
+ shape[i] = 1
+ res = reshape(res, tuple(shape))
+
+ return res
+
+__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/setup.py b/venv/lib/python3.9/site-packages/numpy/array_api/setup.py
new file mode 100644
index 00000000..c8bc2910
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/setup.py
@@ -0,0 +1,12 @@
+def configuration(parent_package="", top_path=None):
+ from numpy.distutils.misc_util import Configuration
+
+ config = Configuration("array_api", parent_package, top_path)
+ config.add_subpackage("tests")
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(configuration=configuration)
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/tests/__init__.py b/venv/lib/python3.9/site-packages/numpy/array_api/tests/__init__.py
new file mode 100644
index 00000000..536062e3
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/tests/__init__.py
@@ -0,0 +1,7 @@
+"""
+Tests for the array API namespace.
+
+Note, full compliance with the array API can be tested with the official array API test
+suite https://github.com/data-apis/array-api-tests. This test suite primarily
+focuses on those things that are not tested by the official test suite.
+"""
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_array_object.py b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_array_object.py
new file mode 100644
index 00000000..f6efacef
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_array_object.py
@@ -0,0 +1,375 @@
+import operator
+
+from numpy.testing import assert_raises
+import numpy as np
+import pytest
+
+from .. import ones, asarray, reshape, result_type, all, equal
+from .._array_object import Array
+from .._dtypes import (
+ _all_dtypes,
+ _boolean_dtypes,
+ _floating_dtypes,
+ _integer_dtypes,
+ _integer_or_boolean_dtypes,
+ _numeric_dtypes,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint64,
+ bool as bool_,
+)
+
+
+def test_validate_index():
+ # The indexing tests in the official array API test suite test that the
+ # array object correctly handles the subset of indices that are required
+ # by the spec. But the NumPy array API implementation specifically
+ # disallows any index not required by the spec, via Array._validate_index.
+ # This test focuses on testing that non-valid indices are correctly
+ # rejected. See
+ # https://data-apis.org/array-api/latest/API_specification/indexing.html
+ # and the docstring of Array._validate_index for the exact indexing
+ # behavior that should be allowed. This does not test indices that are
+ # already invalid in NumPy itself because Array will generally just pass
+ # such indices directly to the underlying np.ndarray.
+
+ a = ones((3, 4))
+
+ # Out of bounds slices are not allowed
+ assert_raises(IndexError, lambda: a[:4])
+ assert_raises(IndexError, lambda: a[:-4])
+ assert_raises(IndexError, lambda: a[:3:-1])
+ assert_raises(IndexError, lambda: a[:-5:-1])
+ assert_raises(IndexError, lambda: a[4:])
+ assert_raises(IndexError, lambda: a[-4:])
+ assert_raises(IndexError, lambda: a[4::-1])
+ assert_raises(IndexError, lambda: a[-4::-1])
+
+ assert_raises(IndexError, lambda: a[...,:5])
+ assert_raises(IndexError, lambda: a[...,:-5])
+ assert_raises(IndexError, lambda: a[...,:5:-1])
+ assert_raises(IndexError, lambda: a[...,:-6:-1])
+ assert_raises(IndexError, lambda: a[...,5:])
+ assert_raises(IndexError, lambda: a[...,-5:])
+ assert_raises(IndexError, lambda: a[...,5::-1])
+ assert_raises(IndexError, lambda: a[...,-5::-1])
+
+ # Boolean indices cannot be part of a larger tuple index
+ assert_raises(IndexError, lambda: a[a[:,0]==1,0])
+ assert_raises(IndexError, lambda: a[a[:,0]==1,...])
+ assert_raises(IndexError, lambda: a[..., a[0]==1])
+ assert_raises(IndexError, lambda: a[[True, True, True]])
+ assert_raises(IndexError, lambda: a[(True, True, True),])
+
+ # Integer array indices are not allowed (except for 0-D)
+ idx = asarray([[0, 1]])
+ assert_raises(IndexError, lambda: a[idx])
+ assert_raises(IndexError, lambda: a[idx,])
+ assert_raises(IndexError, lambda: a[[0, 1]])
+ assert_raises(IndexError, lambda: a[(0, 1), (0, 1)])
+ assert_raises(IndexError, lambda: a[[0, 1]])
+ assert_raises(IndexError, lambda: a[np.array([[0, 1]])])
+
+ # Multiaxis indices must contain exactly as many indices as dimensions
+ assert_raises(IndexError, lambda: a[()])
+ assert_raises(IndexError, lambda: a[0,])
+ assert_raises(IndexError, lambda: a[0])
+ assert_raises(IndexError, lambda: a[:])
+
+def test_operators():
+ # For every operator, we test that it works for the required type
+ # combinations and raises TypeError otherwise
+ binary_op_dtypes = {
+ "__add__": "numeric",
+ "__and__": "integer_or_boolean",
+ "__eq__": "all",
+ "__floordiv__": "numeric",
+ "__ge__": "numeric",
+ "__gt__": "numeric",
+ "__le__": "numeric",
+ "__lshift__": "integer",
+ "__lt__": "numeric",
+ "__mod__": "numeric",
+ "__mul__": "numeric",
+ "__ne__": "all",
+ "__or__": "integer_or_boolean",
+ "__pow__": "numeric",
+ "__rshift__": "integer",
+ "__sub__": "numeric",
+ "__truediv__": "floating",
+ "__xor__": "integer_or_boolean",
+ }
+
+ # Recompute each time because of in-place ops
+ def _array_vals():
+ for d in _integer_dtypes:
+ yield asarray(1, dtype=d)
+ for d in _boolean_dtypes:
+ yield asarray(False, dtype=d)
+ for d in _floating_dtypes:
+ yield asarray(1.0, dtype=d)
+
+ for op, dtypes in binary_op_dtypes.items():
+ ops = [op]
+ if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
+ rop = "__r" + op[2:]
+ iop = "__i" + op[2:]
+ ops += [rop, iop]
+ for s in [1, 1.0, False]:
+ for _op in ops:
+ for a in _array_vals():
+ # Test array op scalar. From the spec, the following combinations
+ # are supported:
+
+ # - Python bool for a bool array dtype,
+ # - a Python int within the bounds of the given dtype for integer array dtypes,
+ # - a Python int or float for floating-point array dtypes
+
+ # We do not do bounds checking for int scalars, but rather use the default
+ # NumPy behavior for casting in that case.
+
+ if ((dtypes == "all"
+ or dtypes == "numeric" and a.dtype in _numeric_dtypes
+ or dtypes == "integer" and a.dtype in _integer_dtypes
+ or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes
+ or dtypes == "boolean" and a.dtype in _boolean_dtypes
+ or dtypes == "floating" and a.dtype in _floating_dtypes
+ )
+ # bool is a subtype of int, which is why we avoid
+ # isinstance here.
+ and (a.dtype in _boolean_dtypes and type(s) == bool
+ or a.dtype in _integer_dtypes and type(s) == int
+ or a.dtype in _floating_dtypes and type(s) in [float, int]
+ )):
+ # Only test for no error
+ getattr(a, _op)(s)
+ else:
+ assert_raises(TypeError, lambda: getattr(a, _op)(s))
+
+ # Test array op array.
+ for _op in ops:
+ for x in _array_vals():
+ for y in _array_vals():
+ # See the promotion table in NEP 47 or the array
+ # API spec page on type promotion. Mixed kind
+ # promotion is not defined.
+ if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
+ or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
+ or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
+ or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
+ or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
+ or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
+ or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
+ or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
+ ):
+ assert_raises(TypeError, lambda: getattr(x, _op)(y))
+ # Ensure in-place operators only promote to the same dtype as the left operand.
+ elif (
+ _op.startswith("__i")
+ and result_type(x.dtype, y.dtype) != x.dtype
+ ):
+ assert_raises(TypeError, lambda: getattr(x, _op)(y))
+ # Ensure only those dtypes that are required for every operator are allowed.
+ elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
+ or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
+ or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
+ or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _numeric_dtypes
+ or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
+ or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes)
+ or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
+ or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes
+ ):
+ getattr(x, _op)(y)
+ else:
+ assert_raises(TypeError, lambda: getattr(x, _op)(y))
+
+ unary_op_dtypes = {
+ "__abs__": "numeric",
+ "__invert__": "integer_or_boolean",
+ "__neg__": "numeric",
+ "__pos__": "numeric",
+ }
+ for op, dtypes in unary_op_dtypes.items():
+ for a in _array_vals():
+ if (
+ dtypes == "numeric"
+ and a.dtype in _numeric_dtypes
+ or dtypes == "integer_or_boolean"
+ and a.dtype in _integer_or_boolean_dtypes
+ ):
+ # Only test for no error
+ getattr(a, op)()
+ else:
+ assert_raises(TypeError, lambda: getattr(a, op)())
+
+ # Finally, matmul() must be tested separately, because it works a bit
+ # different from the other operations.
+ def _matmul_array_vals():
+ for a in _array_vals():
+ yield a
+ for d in _all_dtypes:
+ yield ones((3, 4), dtype=d)
+ yield ones((4, 2), dtype=d)
+ yield ones((4, 4), dtype=d)
+
+ # Scalars always error
+ for _op in ["__matmul__", "__rmatmul__", "__imatmul__"]:
+ for s in [1, 1.0, False]:
+ for a in _matmul_array_vals():
+ if (type(s) in [float, int] and a.dtype in _floating_dtypes
+ or type(s) == int and a.dtype in _integer_dtypes):
+ # Type promotion is valid, but @ is not allowed on 0-D
+ # inputs, so the error is a ValueError
+ assert_raises(ValueError, lambda: getattr(a, _op)(s))
+ else:
+ assert_raises(TypeError, lambda: getattr(a, _op)(s))
+
+ for x in _matmul_array_vals():
+ for y in _matmul_array_vals():
+ if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
+ or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
+ or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
+ or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
+ or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
+ or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
+ or x.dtype in _boolean_dtypes
+ or y.dtype in _boolean_dtypes
+ ):
+ assert_raises(TypeError, lambda: x.__matmul__(y))
+ assert_raises(TypeError, lambda: y.__rmatmul__(x))
+ assert_raises(TypeError, lambda: x.__imatmul__(y))
+ elif x.shape == () or y.shape == () or x.shape[1] != y.shape[0]:
+ assert_raises(ValueError, lambda: x.__matmul__(y))
+ assert_raises(ValueError, lambda: y.__rmatmul__(x))
+ if result_type(x.dtype, y.dtype) != x.dtype:
+ assert_raises(TypeError, lambda: x.__imatmul__(y))
+ else:
+ assert_raises(ValueError, lambda: x.__imatmul__(y))
+ else:
+ x.__matmul__(y)
+ y.__rmatmul__(x)
+ if result_type(x.dtype, y.dtype) != x.dtype:
+ assert_raises(TypeError, lambda: x.__imatmul__(y))
+ elif y.shape[0] != y.shape[1]:
+ # This one fails because x @ y has a different shape from x
+ assert_raises(ValueError, lambda: x.__imatmul__(y))
+ else:
+ x.__imatmul__(y)
+
+
+def test_python_scalar_construtors():
+ b = asarray(False)
+ i = asarray(0)
+ f = asarray(0.0)
+
+ assert bool(b) == False
+ assert int(i) == 0
+ assert float(f) == 0.0
+ assert operator.index(i) == 0
+
+ # bool/int/float should only be allowed on 0-D arrays.
+ assert_raises(TypeError, lambda: bool(asarray([False])))
+ assert_raises(TypeError, lambda: int(asarray([0])))
+ assert_raises(TypeError, lambda: float(asarray([0.0])))
+ assert_raises(TypeError, lambda: operator.index(asarray([0])))
+
+ # bool/int/float should only be allowed on arrays of the corresponding
+ # dtype
+ assert_raises(ValueError, lambda: bool(i))
+ assert_raises(ValueError, lambda: bool(f))
+
+ assert_raises(ValueError, lambda: int(b))
+ assert_raises(ValueError, lambda: int(f))
+
+ assert_raises(ValueError, lambda: float(b))
+ assert_raises(ValueError, lambda: float(i))
+
+ assert_raises(TypeError, lambda: operator.index(b))
+ assert_raises(TypeError, lambda: operator.index(f))
+
+
+def test_device_property():
+ a = ones((3, 4))
+ assert a.device == 'cpu'
+
+ assert all(equal(a.to_device('cpu'), a))
+ assert_raises(ValueError, lambda: a.to_device('gpu'))
+
+ assert all(equal(asarray(a, device='cpu'), a))
+ assert_raises(ValueError, lambda: asarray(a, device='gpu'))
+
+def test_array_properties():
+ a = ones((1, 2, 3))
+ b = ones((2, 3))
+ assert_raises(ValueError, lambda: a.T)
+
+ assert isinstance(b.T, Array)
+ assert b.T.shape == (3, 2)
+
+ assert isinstance(a.mT, Array)
+ assert a.mT.shape == (1, 3, 2)
+ assert isinstance(b.mT, Array)
+ assert b.mT.shape == (3, 2)
+
+def test___array__():
+ a = ones((2, 3), dtype=int16)
+ assert np.asarray(a) is a._array
+ b = np.asarray(a, dtype=np.float64)
+ assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64)))
+ assert b.dtype == np.float64
+
+def test_allow_newaxis():
+ a = ones(5)
+ indexed_a = a[None, :]
+ assert indexed_a.shape == (1, 5)
+
+def test_disallow_flat_indexing_with_newaxis():
+ a = ones((3, 3, 3))
+ with pytest.raises(IndexError):
+ a[None, 0, 0]
+
+def test_disallow_mask_with_newaxis():
+ a = ones((3, 3, 3))
+ with pytest.raises(IndexError):
+ a[None, asarray(True)]
+
+@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)])
+@pytest.mark.parametrize("index", ["string", False, True])
+def test_error_on_invalid_index(shape, index):
+ a = ones(shape)
+ with pytest.raises(IndexError):
+ a[index]
+
+def test_mask_0d_array_without_errors():
+ a = ones(())
+ a[asarray(True)]
+
+@pytest.mark.parametrize(
+ "i", [slice(5), slice(5, 0), asarray(True), asarray([0, 1])]
+)
+def test_error_on_invalid_index_with_ellipsis(i):
+ a = ones((3, 3, 3))
+ with pytest.raises(IndexError):
+ a[..., i]
+ with pytest.raises(IndexError):
+ a[i, ...]
+
+def test_array_keys_use_private_array():
+ """
+ Indexing operations convert array keys before indexing the internal array
+
+ Fails when array_api array keys are not converted into NumPy-proper arrays
+ in __getitem__(). This is achieved by passing array_api arrays with 0-sized
+ dimensions, which NumPy-proper treats erroneously - not sure why!
+
+ TODO: Find and use appropriate __setitem__() case.
+ """
+ a = ones((0, 0), dtype=bool_)
+ assert a[a].shape == (0,)
+
+ a = ones((0,), dtype=bool_)
+ key = ones((0, 0), dtype=bool_)
+ with pytest.raises(IndexError):
+ a[key]
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_creation_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_creation_functions.py
new file mode 100644
index 00000000..be9eaa38
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_creation_functions.py
@@ -0,0 +1,142 @@
+from numpy.testing import assert_raises
+import numpy as np
+
+from .. import all
+from .._creation_functions import (
+ asarray,
+ arange,
+ empty,
+ empty_like,
+ eye,
+ full,
+ full_like,
+ linspace,
+ meshgrid,
+ ones,
+ ones_like,
+ zeros,
+ zeros_like,
+)
+from .._dtypes import float32, float64
+from .._array_object import Array
+
+
+def test_asarray_errors():
+ # Test various protections against incorrect usage
+ assert_raises(TypeError, lambda: Array([1]))
+ assert_raises(TypeError, lambda: asarray(["a"]))
+ assert_raises(ValueError, lambda: asarray([1.0], dtype=np.float16))
+ assert_raises(OverflowError, lambda: asarray(2**100))
+ # Preferably this would be OverflowError
+ # assert_raises(OverflowError, lambda: asarray([2**100]))
+ assert_raises(TypeError, lambda: asarray([2**100]))
+ asarray([1], device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: asarray([1], device="gpu"))
+
+ assert_raises(ValueError, lambda: asarray([1], dtype=int))
+ assert_raises(ValueError, lambda: asarray([1], dtype="i"))
+
+
+def test_asarray_copy():
+ a = asarray([1])
+ b = asarray(a, copy=True)
+ a[0] = 0
+ assert all(b[0] == 1)
+ assert all(a[0] == 0)
+ a = asarray([1])
+ b = asarray(a, copy=np._CopyMode.ALWAYS)
+ a[0] = 0
+ assert all(b[0] == 1)
+ assert all(a[0] == 0)
+ a = asarray([1])
+ b = asarray(a, copy=np._CopyMode.NEVER)
+ a[0] = 0
+ assert all(b[0] == 0)
+ assert_raises(NotImplementedError, lambda: asarray(a, copy=False))
+ assert_raises(NotImplementedError,
+ lambda: asarray(a, copy=np._CopyMode.IF_NEEDED))
+
+
+def test_arange_errors():
+ arange(1, device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: arange(1, device="gpu"))
+ assert_raises(ValueError, lambda: arange(1, dtype=int))
+ assert_raises(ValueError, lambda: arange(1, dtype="i"))
+
+
+def test_empty_errors():
+ empty((1,), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: empty((1,), device="gpu"))
+ assert_raises(ValueError, lambda: empty((1,), dtype=int))
+ assert_raises(ValueError, lambda: empty((1,), dtype="i"))
+
+
+def test_empty_like_errors():
+ empty_like(asarray(1), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: empty_like(asarray(1), device="gpu"))
+ assert_raises(ValueError, lambda: empty_like(asarray(1), dtype=int))
+ assert_raises(ValueError, lambda: empty_like(asarray(1), dtype="i"))
+
+
+def test_eye_errors():
+ eye(1, device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: eye(1, device="gpu"))
+ assert_raises(ValueError, lambda: eye(1, dtype=int))
+ assert_raises(ValueError, lambda: eye(1, dtype="i"))
+
+
+def test_full_errors():
+ full((1,), 0, device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: full((1,), 0, device="gpu"))
+ assert_raises(ValueError, lambda: full((1,), 0, dtype=int))
+ assert_raises(ValueError, lambda: full((1,), 0, dtype="i"))
+
+
+def test_full_like_errors():
+ full_like(asarray(1), 0, device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="gpu"))
+ assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype=int))
+ assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype="i"))
+
+
+def test_linspace_errors():
+ linspace(0, 1, 10, device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: linspace(0, 1, 10, device="gpu"))
+ assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype=float))
+ assert_raises(ValueError, lambda: linspace(0, 1, 10, dtype="f"))
+
+
+def test_ones_errors():
+ ones((1,), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: ones((1,), device="gpu"))
+ assert_raises(ValueError, lambda: ones((1,), dtype=int))
+ assert_raises(ValueError, lambda: ones((1,), dtype="i"))
+
+
+def test_ones_like_errors():
+ ones_like(asarray(1), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: ones_like(asarray(1), device="gpu"))
+ assert_raises(ValueError, lambda: ones_like(asarray(1), dtype=int))
+ assert_raises(ValueError, lambda: ones_like(asarray(1), dtype="i"))
+
+
+def test_zeros_errors():
+ zeros((1,), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: zeros((1,), device="gpu"))
+ assert_raises(ValueError, lambda: zeros((1,), dtype=int))
+ assert_raises(ValueError, lambda: zeros((1,), dtype="i"))
+
+
+def test_zeros_like_errors():
+ zeros_like(asarray(1), device="cpu") # Doesn't error
+ assert_raises(ValueError, lambda: zeros_like(asarray(1), device="gpu"))
+ assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype=int))
+ assert_raises(ValueError, lambda: zeros_like(asarray(1), dtype="i"))
+
+def test_meshgrid_dtype_errors():
+ # Doesn't raise
+ meshgrid()
+ meshgrid(asarray([1.], dtype=float32))
+ meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32))
+
+ assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64)))
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_data_type_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_data_type_functions.py
new file mode 100644
index 00000000..efe3d0ab
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_data_type_functions.py
@@ -0,0 +1,19 @@
+import pytest
+
+from numpy import array_api as xp
+
+
+@pytest.mark.parametrize(
+ "from_, to, expected",
+ [
+ (xp.int8, xp.int16, True),
+ (xp.int16, xp.int8, False),
+ (xp.bool, xp.int8, False),
+ (xp.asarray(0, dtype=xp.uint8), xp.int8, False),
+ ],
+)
+def test_can_cast(from_, to, expected):
+ """
+ can_cast() returns correct result
+ """
+ assert xp.can_cast(from_, to) == expected
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_elementwise_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_elementwise_functions.py
new file mode 100644
index 00000000..b2fb44e7
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_elementwise_functions.py
@@ -0,0 +1,111 @@
+from inspect import getfullargspec
+
+from numpy.testing import assert_raises
+
+from .. import asarray, _elementwise_functions
+from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
+from .._dtypes import (
+ _dtype_categories,
+ _boolean_dtypes,
+ _floating_dtypes,
+ _integer_dtypes,
+)
+
+
+def nargs(func):
+ return len(getfullargspec(func).args)
+
+
+def test_function_types():
+ # Test that every function accepts only the required input types. We only
+ # test the negative cases here (error). The positive cases are tested in
+ # the array API test suite.
+
+ elementwise_function_input_types = {
+ "abs": "numeric",
+ "acos": "floating-point",
+ "acosh": "floating-point",
+ "add": "numeric",
+ "asin": "floating-point",
+ "asinh": "floating-point",
+ "atan": "floating-point",
+ "atan2": "floating-point",
+ "atanh": "floating-point",
+ "bitwise_and": "integer or boolean",
+ "bitwise_invert": "integer or boolean",
+ "bitwise_left_shift": "integer",
+ "bitwise_or": "integer or boolean",
+ "bitwise_right_shift": "integer",
+ "bitwise_xor": "integer or boolean",
+ "ceil": "numeric",
+ "cos": "floating-point",
+ "cosh": "floating-point",
+ "divide": "floating-point",
+ "equal": "all",
+ "exp": "floating-point",
+ "expm1": "floating-point",
+ "floor": "numeric",
+ "floor_divide": "numeric",
+ "greater": "numeric",
+ "greater_equal": "numeric",
+ "isfinite": "numeric",
+ "isinf": "numeric",
+ "isnan": "numeric",
+ "less": "numeric",
+ "less_equal": "numeric",
+ "log": "floating-point",
+ "logaddexp": "floating-point",
+ "log10": "floating-point",
+ "log1p": "floating-point",
+ "log2": "floating-point",
+ "logical_and": "boolean",
+ "logical_not": "boolean",
+ "logical_or": "boolean",
+ "logical_xor": "boolean",
+ "multiply": "numeric",
+ "negative": "numeric",
+ "not_equal": "all",
+ "positive": "numeric",
+ "pow": "numeric",
+ "remainder": "numeric",
+ "round": "numeric",
+ "sign": "numeric",
+ "sin": "floating-point",
+ "sinh": "floating-point",
+ "sqrt": "floating-point",
+ "square": "numeric",
+ "subtract": "numeric",
+ "tan": "floating-point",
+ "tanh": "floating-point",
+ "trunc": "numeric",
+ }
+
+ def _array_vals():
+ for d in _integer_dtypes:
+ yield asarray(1, dtype=d)
+ for d in _boolean_dtypes:
+ yield asarray(False, dtype=d)
+ for d in _floating_dtypes:
+ yield asarray(1.0, dtype=d)
+
+ for x in _array_vals():
+ for func_name, types in elementwise_function_input_types.items():
+ dtypes = _dtype_categories[types]
+ func = getattr(_elementwise_functions, func_name)
+ if nargs(func) == 2:
+ for y in _array_vals():
+ if x.dtype not in dtypes or y.dtype not in dtypes:
+ assert_raises(TypeError, lambda: func(x, y))
+ else:
+ if x.dtype not in dtypes:
+ assert_raises(TypeError, lambda: func(x))
+
+
+def test_bitwise_shift_error():
+ # bitwise shift functions should raise when the second argument is negative
+ assert_raises(
+ ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))
+ )
+ assert_raises(
+ ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))
+ )
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_set_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_set_functions.py
new file mode 100644
index 00000000..b8eb65d4
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_set_functions.py
@@ -0,0 +1,19 @@
+import pytest
+from hypothesis import given
+from hypothesis.extra.array_api import make_strategies_namespace
+
+from numpy import array_api as xp
+
+xps = make_strategies_namespace(xp)
+
+
+@pytest.mark.parametrize("func", [xp.unique_all, xp.unique_inverse])
+@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=xps.array_shapes()))
+def test_inverse_indices_shape(func, x):
+ """
+ Inverse indices share shape of input array
+
+ See https://github.com/numpy/numpy/issues/20638
+ """
+ out = func(x)
+ assert out.inverse_indices.shape == x.shape
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_sorting_functions.py b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_sorting_functions.py
new file mode 100644
index 00000000..9848bbfe
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_sorting_functions.py
@@ -0,0 +1,23 @@
+import pytest
+
+from numpy import array_api as xp
+
+
+@pytest.mark.parametrize(
+ "obj, axis, expected",
+ [
+ ([0, 0], -1, [0, 1]),
+ ([0, 1, 0], -1, [1, 0, 2]),
+ ([[0, 1], [1, 1]], 0, [[1, 0], [0, 1]]),
+ ([[0, 1], [1, 1]], 1, [[1, 0], [0, 1]]),
+ ],
+)
+def test_stable_desc_argsort(obj, axis, expected):
+ """
+ Indices respect relative order of a descending stable-sort
+
+ See https://github.com/numpy/numpy/issues/20778
+ """
+ x = xp.asarray(obj)
+ out = xp.argsort(x, axis=axis, stable=True, descending=True)
+ assert xp.all(out == xp.asarray(expected))
diff --git a/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_validation.py b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_validation.py
new file mode 100644
index 00000000..0dd100d1
--- /dev/null
+++ b/venv/lib/python3.9/site-packages/numpy/array_api/tests/test_validation.py
@@ -0,0 +1,27 @@
+from typing import Callable
+
+import pytest
+
+from numpy import array_api as xp
+
+
+def p(func: Callable, *args, **kwargs):
+ f_sig = ", ".join(
+ [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs.items()]
+ )
+ id_ = f"{func.__name__}({f_sig})"
+ return pytest.param(func, args, kwargs, id=id_)
+
+
+@pytest.mark.parametrize(
+ "func, args, kwargs",
+ [
+ p(xp.can_cast, 42, xp.int8),
+ p(xp.can_cast, xp.int8, 42),
+ p(xp.result_type, 42),
+ ],
+)
+def test_raises_on_invalid_types(func, args, kwargs):
+ """Function raises TypeError when passed invalidly-typed inputs"""
+ with pytest.raises(TypeError):
+ func(*args, **kwargs)