Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content

Commit cb1f865

Browse files
committed
API: adjust nD fft s param to array API
1 parent 284a2f0 commit cb1f865

File tree

2 files changed

+42
-14
lines changed

2 files changed

+42
-14
lines changed

numpy/fft/_pocketfft.py

+29-14
Original file line numberDiff line numberDiff line change
@@ -681,20 +681,24 @@ def ihfft(a, n=None, axis=-1, norm=None):
681681

682682
def _cook_nd_args(a, s=None, axes=None, invreal=0):
683683
if s is None:
684-
shapeless = 1
684+
shapeless = True
685685
if axes is None:
686686
s = list(a.shape)
687687
else:
688688
s = take(a.shape, axes)
689689
else:
690-
shapeless = 0
690+
shapeless = False
691691
s = list(s)
692692
if axes is None:
693+
if not shapeless:
694+
msg = "`axes` must not be `None` if `s` is not `None`."
695+
raise ValueError(msg)
693696
axes = list(range(-len(s), 0))
694697
if len(s) != len(axes):
695698
raise ValueError("Shape and axes have different lengths.")
696699
if invreal and shapeless:
697700
s[-1] = (a.shape[axes[-1]] - 1) * 2
701+
s = [a.shape[_a] if _s == -1 else _s for _s, _a in zip(s, axes)]
698702
return s, axes
699703

700704

@@ -730,9 +734,11 @@ def fftn(a, s=None, axes=None, norm=None):
730734
(``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.).
731735
This corresponds to ``n`` for ``fft(x, n)``.
732736
Along any axis, if the given shape is smaller than that of the input,
733-
the input is cropped. If it is larger, the input is padded with zeros.
734-
if `s` is not given, the shape of the input along the axes specified
737+
the input is cropped. If it is larger, the input is padded with zeros.
738+
If it is ``-1``, the whole input is used (no padding/trimming).
739+
If `s` is not given, the shape of the input along the axes specified
735740
by `axes` is used.
741+
If `s` is not ``None``, `axes` must not be ``None`` either.
736742
axes : sequence of ints, optional
737743
Axes over which to compute the FFT. If not given, the last ``len(s)``
738744
axes are used, or all axes if `s` is also not specified.
@@ -842,9 +848,11 @@ def ifftn(a, s=None, axes=None, norm=None):
842848
(``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.).
843849
This corresponds to ``n`` for ``ifft(x, n)``.
844850
Along any axis, if the given shape is smaller than that of the input,
845-
the input is cropped. If it is larger, the input is padded with zeros.
846-
if `s` is not given, the shape of the input along the axes specified
847-
by `axes` is used. See notes for issue on `ifft` zero padding.
851+
the input is cropped. If it is larger, the input is padded with zeros.
852+
If it is ``-1``, the whole input is used (no padding/trimming).
853+
If `s` is not given, the shape of the input along the axes specified
854+
by `axes` is used. See notes for issue on `ifft` zero padding.
855+
If `s` is not ``None``, `axes` must not be ``None`` either.
848856
axes : sequence of ints, optional
849857
Axes over which to compute the IFFT. If not given, the last ``len(s)``
850858
axes are used, or all axes if `s` is also not specified.
@@ -937,8 +945,9 @@ def fft2(a, s=None, axes=(-2, -1), norm=None):
937945
(``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.).
938946
This corresponds to ``n`` for ``fft(x, n)``.
939947
Along each axis, if the given shape is smaller than that of the input,
940-
the input is cropped. If it is larger, the input is padded with zeros.
941-
if `s` is not given, the shape of the input along the axes specified
948+
the input is cropped. If it is larger, the input is padded with zeros.
949+
If it is ``-1``, the whole input is used (no padding/trimming).
950+
If `s` is not given, the shape of the input along the axes specified
942951
by `axes` is used.
943952
axes : sequence of ints, optional
944953
Axes over which to compute the FFT. If not given, the last two
@@ -1040,8 +1049,9 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None):
10401049
Shape (length of each axis) of the output (``s[0]`` refers to axis 0,
10411050
``s[1]`` to axis 1, etc.). This corresponds to `n` for ``ifft(x, n)``.
10421051
Along each axis, if the given shape is smaller than that of the input,
1043-
the input is cropped. If it is larger, the input is padded with zeros.
1044-
if `s` is not given, the shape of the input along the axes specified
1052+
the input is cropped. If it is larger, the input is padded with zeros.
1053+
If it is ``-1``, the whole input is used (no padding/trimming).
1054+
If `s` is not given, the shape of the input along the axes specified
10451055
by `axes` is used. See notes for issue on `ifft` zero padding.
10461056
axes : sequence of ints, optional
10471057
Axes over which to compute the FFT. If not given, the last two
@@ -1128,9 +1138,11 @@ def rfftn(a, s=None, axes=None, norm=None):
11281138
The final element of `s` corresponds to `n` for ``rfft(x, n)``, while
11291139
for the remaining axes, it corresponds to `n` for ``fft(x, n)``.
11301140
Along any axis, if the given shape is smaller than that of the input,
1131-
the input is cropped. If it is larger, the input is padded with zeros.
1132-
if `s` is not given, the shape of the input along the axes specified
1141+
the input is cropped. If it is larger, the input is padded with zeros.
1142+
If it is ``-1``, the whole input is used (no padding/trimming).
1143+
If `s` is not given, the shape of the input along the axes specified
11331144
by `axes` is used.
1145+
If `s` is not ``None``, `axes` must not be ``None`` either.
11341146
axes : sequence of ints, optional
11351147
Axes over which to compute the FFT. If not given, the last ``len(s)``
11361148
axes are used, or all axes if `s` is also not specified.
@@ -1284,9 +1296,12 @@ def irfftn(a, s=None, axes=None, norm=None):
12841296
where ``s[-1]//2+1`` points of the input are used.
12851297
Along any axis, if the shape indicated by `s` is smaller than that of
12861298
the input, the input is cropped. If it is larger, the input is padded
1287-
with zeros. If `s` is not given, the shape of the input along the axes
1299+
with zeros.
1300+
If it is ``-1``, the whole input is used (no padding/trimming).
1301+
If `s` is not given, the shape of the input along the axes
12881302
specified by axes is used. Except for the last axis which is taken to
12891303
be ``2*(m-1)`` where ``m`` is the length of the input along that axis.
1304+
If `s` is not ``None``, `axes` must not be ``None`` either.
12901305
axes : sequence of ints, optional
12911306
Axes over which to compute the inverse FFT. If not given, the last
12921307
`len(s)` axes are used, or all axes if `s` is also not specified.

numpy/fft/tests/test_pocketfft.py

+13
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,19 @@ def test_axes(self, op):
199199
op_tr = op(np.transpose(x, a))
200200
tr_op = np.transpose(op(x, axes=a), a)
201201
assert_allclose(op_tr, tr_op, atol=1e-6)
202+
203+
@pytest.mark.parametrize("op", [np.fft.fftn, np.fft.ifftn])
204+
def test_s_negative_1(self, op):
205+
x = np.arange(100).reshape(10, 10)
206+
# should use the whole input array along the first axis
207+
assert op(x, s=(-1, 5), axes=(0, 1)).shape == (10, 5)
208+
209+
@pytest.mark.parametrize("op", [np.fft.fftn, np.fft.ifftn,
210+
np.fft.rfftn, np.fft.irfftn])
211+
def test_s_axes_none(self, op):
212+
x = np.arange(100).reshape(10, 10)
213+
with pytest.raises(ValueError):
214+
op(x, s=(-1, 5))
202215

203216
def test_all_1d_norm_preserving(self):
204217
# verify that round-trip transforms are norm-preserving

0 commit comments

Comments
 (0)