Skip to content

Commit 17bfa05

Browse files
committed
allow a sequence of dtypes in the scalars(...) strategy
use it to simplify the searchorted_with_scalars test
1 parent 28f86cc commit 17bfa05

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,9 +454,12 @@ def scalars(draw, dtypes, finite=False, **kwds):
454454
"""
455455
Strategy to generate a scalar that matches a dtype strategy
456456
457-
dtypes should be one of the shared_* dtypes strategies.
457+
dtypes should be one of the shared_* dtypes strategies or a sequence of dtypes.
458458
"""
459-
dtype = draw(dtypes)
459+
if isinstance(dtypes, Sequence):
460+
dtype = draw(sampled_from(dtypes))
461+
else:
462+
dtype = draw(dtypes)
460463
mM = kwds.pop('mM', None)
461464
if dh.is_int_dtype(dtype):
462465
if mM is None:

array_api_tests/test_searching_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def test_searchsorted_with_scalars(data):
316316
kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"])))
317317

318318
# 2. draw x2, a real-valued scalar (IOW, an int or a float)
319-
x2 = data.draw(hh.scalars(st.sampled_from([xp.int32, xp.float64]), finite=True))
319+
x2 = data.draw(hh.scalars([xp.int32, xp.float64], finite=True))
320320

321321
# 3. testing: similar to test_searchsorted, modulo `out.shape == ()`
322322
repro_snippet = ph.format_snippet(

0 commit comments

Comments
 (0)