Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content
This repository was archived by the owner on Jun 10, 2020. It is now read-only.

Commit fa6b9fd

Browse files
authored
ENH: get more specific about _ArrayLike, make it public (#66)
Closes #37. Add tests to check various examples. Note that supporting __array__ also requires making _DtypeLike public too, so this does that as well.
1 parent caef625 commit fa6b9fd

File tree

5 files changed

+183
-101
lines changed

5 files changed

+183
-101
lines changed

numpy-stubs/__init__.pyi

+51-99
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ import sys
33
import datetime as dt
44

55
from numpy.core._internal import _ctypes
6+
from numpy.typing import ArrayLike, DtypeLike, _Shape, _ShapeLike
7+
68
from typing import (
79
Any,
810
ByteString,
@@ -36,69 +38,18 @@ else:
3638
from typing import SupportsBytes
3739

3840
if sys.version_info >= (3, 8):
39-
from typing import Literal
41+
from typing import Literal, Protocol
4042
else:
41-
from typing_extensions import Literal
43+
from typing_extensions import Literal, Protocol
4244

4345
# TODO: remove when the full numpy namespace is defined
4446
def __getattr__(name: str) -> Any: ...
4547

46-
_Shape = Tuple[int, ...]
47-
48-
# Anything that can be coerced to a shape tuple
49-
_ShapeLike = Union[int, Sequence[int]]
50-
51-
_DtypeLikeNested = Any # TODO: wait for support for recursive types
52-
53-
# Anything that can be coerced into numpy.dtype.
54-
# Reference: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
55-
_DtypeLike = Union[
56-
dtype,
57-
# default data type (float64)
58-
None,
59-
# array-scalar types and generic types
60-
type, # TODO: enumerate these when we add type hints for numpy scalars
61-
# TODO: add a protocol for anything with a dtype attribute
62-
# character codes, type strings or comma-separated fields, e.g., 'float64'
63-
str,
64-
# (flexible_dtype, itemsize)
65-
Tuple[_DtypeLikeNested, int],
66-
# (fixed_dtype, shape)
67-
Tuple[_DtypeLikeNested, _ShapeLike],
68-
# [(field_name, field_dtype, field_shape), ...]
69-
#
70-
# The type here is quite broad because NumPy accepts quite a wide
71-
# range of inputs inside the list; see the tests for some
72-
# examples.
73-
List[Any],
74-
# {'names': ..., 'formats': ..., 'offsets': ..., 'titles': ...,
75-
# 'itemsize': ...}
76-
# TODO: use TypedDict when/if it's officially supported
77-
Dict[
78-
str,
79-
Union[
80-
Sequence[str], # names
81-
Sequence[_DtypeLikeNested], # formats
82-
Sequence[int], # offsets
83-
Sequence[Union[bytes, Text, None]], # titles
84-
int, # itemsize
85-
],
86-
],
87-
# {'field1': ..., 'field2': ..., ...}
88-
Dict[str, Tuple[_DtypeLikeNested, int]],
89-
# (base_dtype, new_dtype)
90-
Tuple[_DtypeLikeNested, _DtypeLikeNested],
91-
]
92-
9348
_NdArraySubClass = TypeVar("_NdArraySubClass", bound=ndarray)
9449

95-
_ArrayLike = TypeVar("_ArrayLike")
96-
9750
class dtype:
9851
names: Optional[Tuple[str, ...]]
99-
def __init__(
100-
self, obj: _DtypeLike, align: bool = ..., copy: bool = ...
101-
) -> None: ...
52+
def __init__(self, obj: DtypeLike, align: bool = ..., copy: bool = ...) -> None: ...
10253
@property
10354
def alignment(self) -> int: ...
10455
@property
@@ -217,6 +168,7 @@ class _ArrayOrScalarCommon(
217168
def shape(self) -> _Shape: ...
218169
@property
219170
def strides(self) -> _Shape: ...
171+
def __array__(self, __dtype: DtypeLike = ...) -> ndarray: ...
220172
def __int__(self) -> int: ...
221173
def __float__(self) -> float: ...
222174
def __complex__(self) -> complex: ...
@@ -299,7 +251,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
299251
def __new__(
300252
cls,
301253
shape: Sequence[int],
302-
dtype: Union[_DtypeLike, str] = ...,
254+
dtype: DtypeLike = ...,
303255
buffer: _BufferType = ...,
304256
offset: int = ...,
305257
strides: _ShapeLike = ...,
@@ -338,7 +290,7 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
338290
def dumps(self) -> bytes: ...
339291
def astype(
340292
self,
341-
dtype: _DtypeLike,
293+
dtype: DtypeLike,
342294
order: str = ...,
343295
casting: str = ...,
344296
subok: bool = ...,
@@ -349,14 +301,14 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
349301
@overload
350302
def view(self, dtype: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
351303
@overload
352-
def view(self, dtype: _DtypeLike = ...) -> ndarray: ...
304+
def view(self, dtype: DtypeLike = ...) -> ndarray: ...
353305
@overload
354306
def view(
355-
self, dtype: _DtypeLike, type: Type[_NdArraySubClass]
307+
self, dtype: DtypeLike, type: Type[_NdArraySubClass]
356308
) -> _NdArraySubClass: ...
357309
@overload
358310
def view(self, *, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
359-
def getfield(self, dtype: Union[_DtypeLike, str], offset: int = ...) -> ndarray: ...
311+
def getfield(self, dtype: DtypeLike, offset: int = ...) -> ndarray: ...
360312
def setflags(
361313
self, write: bool = ..., align: bool = ..., uic: bool = ...
362314
) -> None: ...
@@ -501,26 +453,26 @@ class str_(character): ...
501453

502454
def array(
503455
object: object,
504-
dtype: _DtypeLike = ...,
456+
dtype: DtypeLike = ...,
505457
copy: bool = ...,
506458
subok: bool = ...,
507459
ndmin: int = ...,
508460
) -> ndarray: ...
509461
def zeros(
510-
shape: _ShapeLike, dtype: _DtypeLike = ..., order: Optional[str] = ...
462+
shape: _ShapeLike, dtype: DtypeLike = ..., order: Optional[str] = ...
511463
) -> ndarray: ...
512464
def ones(
513-
shape: _ShapeLike, dtype: _DtypeLike = ..., order: Optional[str] = ...
465+
shape: _ShapeLike, dtype: DtypeLike = ..., order: Optional[str] = ...
514466
) -> ndarray: ...
515467
def zeros_like(
516-
a: _ArrayLike,
468+
a: ArrayLike,
517469
dtype: Optional[dtype] = ...,
518470
order: str = ...,
519471
subok: bool = ...,
520472
shape: Optional[Union[int, Sequence[int]]] = ...,
521473
) -> ndarray: ...
522474
def ones_like(
523-
a: _ArrayLike,
475+
a: ArrayLike,
524476
dtype: Optional[dtype] = ...,
525477
order: str = ...,
526478
subok: bool = ...,
@@ -530,43 +482,43 @@ def full(
530482
shape: _ShapeLike, fill_value: Any, dtype: Optional[dtype] = ..., order: str = ...
531483
) -> ndarray: ...
532484
def full_like(
533-
a: _ArrayLike,
485+
a: ArrayLike,
534486
fill_value: Any,
535487
dtype: Optional[dtype] = ...,
536488
order: str = ...,
537489
subok: bool = ...,
538490
shape: Optional[_ShapeLike] = ...,
539491
) -> ndarray: ...
540492
def count_nonzero(
541-
a: _ArrayLike, axis: Optional[Union[int, Tuple[int], Tuple[int, int]]] = ...
493+
a: ArrayLike, axis: Optional[Union[int, Tuple[int], Tuple[int, int]]] = ...
542494
) -> Union[int, ndarray]: ...
543495
def isfortran(a: ndarray) -> bool: ...
544-
def argwhere(a: _ArrayLike) -> ndarray: ...
545-
def flatnonzero(a: _ArrayLike) -> ndarray: ...
546-
def correlate(a: _ArrayLike, v: _ArrayLike, mode: str = ...) -> ndarray: ...
547-
def convolve(a: _ArrayLike, v: _ArrayLike, mode: str = ...) -> ndarray: ...
548-
def outer(a: _ArrayLike, b: _ArrayLike, out: ndarray = ...) -> ndarray: ...
496+
def argwhere(a: ArrayLike) -> ndarray: ...
497+
def flatnonzero(a: ArrayLike) -> ndarray: ...
498+
def correlate(a: ArrayLike, v: ArrayLike, mode: str = ...) -> ndarray: ...
499+
def convolve(a: ArrayLike, v: ArrayLike, mode: str = ...) -> ndarray: ...
500+
def outer(a: ArrayLike, b: ArrayLike, out: ndarray = ...) -> ndarray: ...
549501
def tensordot(
550-
a: _ArrayLike,
551-
b: _ArrayLike,
502+
a: ArrayLike,
503+
b: ArrayLike,
552504
axes: Union[
553505
int, Tuple[int, int], Tuple[Tuple[int, int], ...], Tuple[List[int, int], ...]
554506
] = ...,
555507
) -> ndarray: ...
556508
def roll(
557-
a: _ArrayLike,
509+
a: ArrayLike,
558510
shift: Union[int, Tuple[int, ...]],
559511
axis: Optional[Union[int, Tuple[int, ...]]] = ...,
560512
) -> ndarray: ...
561-
def rollaxis(a: _ArrayLike, axis: int, start: int = ...) -> ndarray: ...
513+
def rollaxis(a: ArrayLike, axis: int, start: int = ...) -> ndarray: ...
562514
def moveaxis(
563515
a: ndarray,
564516
source: Union[int, Sequence[int]],
565517
destination: Union[int, Sequence[int]],
566518
) -> ndarray: ...
567519
def cross(
568-
a: _ArrayLike,
569-
b: _ArrayLike,
520+
a: ArrayLike,
521+
b: ArrayLike,
570522
axisa: int = ...,
571523
axisb: int = ...,
572524
axisc: int = ...,
@@ -581,21 +533,21 @@ def binary_repr(num: int, width: Optional[int] = ...) -> str: ...
581533
def base_repr(number: int, base: int = ..., padding: int = ...) -> str: ...
582534
def identity(n: int, dtype: Optional[dtype] = ...) -> ndarray: ...
583535
def allclose(
584-
a: _ArrayLike,
585-
b: _ArrayLike,
536+
a: ArrayLike,
537+
b: ArrayLike,
586538
rtol: float = ...,
587539
atol: float = ...,
588540
equal_nan: bool = ...,
589541
) -> bool: ...
590542
def isclose(
591-
a: _ArrayLike,
592-
b: _ArrayLike,
543+
a: ArrayLike,
544+
b: ArrayLike,
593545
rtol: float = ...,
594546
atol: float = ...,
595547
equal_nan: bool = ...,
596548
) -> Union[bool_, ndarray]: ...
597-
def array_equal(a1: _ArrayLike, a2: _ArrayLike) -> bool: ...
598-
def array_equiv(a1: _ArrayLike, a2: _ArrayLike) -> bool: ...
549+
def array_equal(a1: ArrayLike, a2: ArrayLike) -> bool: ...
550+
def array_equiv(a1: ArrayLike, a2: ArrayLike) -> bool: ...
599551

600552
#
601553
# Constants
@@ -649,7 +601,7 @@ class ufunc:
649601
def __name__(self) -> str: ...
650602
def __call__(
651603
self,
652-
*args: _ArrayLike,
604+
*args: ArrayLike,
653605
out: Optional[Union[ndarray, Tuple[ndarray, ...]]] = ...,
654606
where: Optional[ndarray] = ...,
655607
# The list should be a list of tuples of ints, but since we
@@ -664,7 +616,7 @@ class ufunc:
664616
casting: str = ...,
665617
# TODO: make this precise when we can use Literal.
666618
order: Optional[str] = ...,
667-
dtype: Optional[_DtypeLike] = ...,
619+
dtype: DtypeLike = ...,
668620
subok: bool = ...,
669621
signature: Union[str, Tuple[str]] = ...,
670622
# In reality this should be a length of list 3 containing an
@@ -876,56 +828,56 @@ def take(
876828
) -> _ScalarNumpy: ...
877829
@overload
878830
def take(
879-
a: _ArrayLike,
831+
a: ArrayLike,
880832
indices: int,
881833
axis: Optional[int] = ...,
882834
out: Optional[ndarray] = ...,
883835
mode: _Mode = ...,
884836
) -> _ScalarNumpy: ...
885837
@overload
886838
def take(
887-
a: _ArrayLike,
839+
a: ArrayLike,
888840
indices: _ArrayLikeIntOrBool,
889841
axis: Optional[int] = ...,
890842
out: Optional[ndarray] = ...,
891843
mode: _Mode = ...,
892844
) -> Union[_ScalarNumpy, ndarray]: ...
893-
def reshape(a: _ArrayLike, newshape: _ShapeLike, order: _Order = ...) -> ndarray: ...
845+
def reshape(a: ArrayLike, newshape: _ShapeLike, order: _Order = ...) -> ndarray: ...
894846
@overload
895847
def choose(
896848
a: _ScalarIntOrBool,
897-
choices: Union[Sequence[_ArrayLike], ndarray],
849+
choices: Union[Sequence[ArrayLike], ndarray],
898850
out: Optional[ndarray] = ...,
899851
mode: _Mode = ...,
900852
) -> _ScalarIntOrBool: ...
901853
@overload
902854
def choose(
903855
a: _IntOrBool,
904-
choices: Union[Sequence[_ArrayLike], ndarray],
856+
choices: Union[Sequence[ArrayLike], ndarray],
905857
out: Optional[ndarray] = ...,
906858
mode: _Mode = ...,
907859
) -> Union[integer, bool_]: ...
908860
@overload
909861
def choose(
910862
a: _ArrayLikeIntOrBool,
911-
choices: Union[Sequence[_ArrayLike], ndarray],
863+
choices: Union[Sequence[ArrayLike], ndarray],
912864
out: Optional[ndarray] = ...,
913865
mode: _Mode = ...,
914866
) -> ndarray: ...
915867
def repeat(
916-
a: _ArrayLike, repeats: _ArrayLikeIntOrBool, axis: Optional[int] = ...
868+
a: ArrayLike, repeats: _ArrayLikeIntOrBool, axis: Optional[int] = ...
917869
) -> ndarray: ...
918870
def put(
919-
a: ndarray, ind: _ArrayLikeIntOrBool, v: _ArrayLike, mode: _Mode = ...
871+
a: ndarray, ind: _ArrayLikeIntOrBool, v: ArrayLike, mode: _Mode = ...
920872
) -> None: ...
921873
def swapaxes(
922-
a: Union[Sequence[_ArrayLike], ndarray], axis1: int, axis2: int
874+
a: Union[Sequence[ArrayLike], ndarray], axis1: int, axis2: int
923875
) -> ndarray: ...
924876
def transpose(
925-
a: _ArrayLike, axes: Union[None, Sequence[int], ndarray] = ...
877+
a: ArrayLike, axes: Union[None, Sequence[int], ndarray] = ...
926878
) -> ndarray: ...
927879
def partition(
928-
a: _ArrayLike,
880+
a: ArrayLike,
929881
kth: _ArrayLikeIntOrBool,
930882
axis: Optional[int] = ...,
931883
kind: _PartitionKind = ...,
@@ -949,20 +901,20 @@ def argpartition(
949901
) -> ndarray: ...
950902
@overload
951903
def argpartition(
952-
a: _ArrayLike,
904+
a: ArrayLike,
953905
kth: _ArrayLikeIntOrBool,
954906
axis: Optional[int] = ...,
955907
kind: _PartitionKind = ...,
956908
order: Union[None, str, Sequence[str]] = ...,
957909
) -> ndarray: ...
958910
def sort(
959-
a: Union[Sequence[_ArrayLike], ndarray],
911+
a: Union[Sequence[ArrayLike], ndarray],
960912
axis: Optional[int] = ...,
961913
kind: Optional[_SortKind] = ...,
962914
order: Union[None, str, Sequence[str]] = ...,
963915
) -> ndarray: ...
964916
def argsort(
965-
a: Union[Sequence[_ArrayLike], ndarray],
917+
a: Union[Sequence[ArrayLike], ndarray],
966918
axis: Optional[int] = ...,
967919
kind: Optional[_SortKind] = ...,
968920
order: Union[None, str, Sequence[str]] = ...,

0 commit comments

Comments
 (0)