Skip to content
1 change: 1 addition & 0 deletions changes/3547.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add an `array.target_shard_size_bytes` setting to [`zarr.config`][] to allow users to set a maximum number of bytes per shard when `shards="auto"` in, for example, [`zarr.create_array`][].
2 changes: 2 additions & 0 deletions docs/user-guide/performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ z6 = zarr.create_array(store={}, shape=(10000, 10000, 1000), shards=(1000, 1000,
print(z6.info)
```

`shards` can be `"auto"` as well, in which case the `array.target_shard_size_bytes` setting can be used to control the size of shards (i.e., the shard size will be as close to `target_shard_size_bytes` as possible without exceeding it); otherwise, a default is used.

### Chunk memory layout

The order of bytes **within each chunk** of an array can be changed via the
Expand Down
58 changes: 53 additions & 5 deletions src/zarr/core/chunk_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np

import zarr
from zarr.abc.metadata import Metadata
from zarr.core.common import (
JSON,
Expand Down Expand Up @@ -202,6 +203,43 @@ def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
)


def _guess_num_chunks_per_axis_shard(
chunk_shape: tuple[int, ...], item_size: int, max_bytes: int, array_shape: tuple[int, ...]
) -> int:
"""Generate the number of chunks per axis to hit a target max byte size for a shard.

For example, for a (2,2,2) chunk size and item size 4, maximum bytes of 256 would return 2.
In other words the shard would be a (2,2,2) grid of (2,2,2) chunks
i.e., prod(chunk_shape) * (returned_val * len(chunk_shape)) * item_size = 256 bytes.

Parameters
----------
chunk_shape
The shape of the (inner) chunks.
item_size
The item size of the data i.e., 2 for uint16.
max_bytes
The maximum number of bytes per shard to allow.
array_shape
The shape of the underlying array.

Returns
-------
The number of chunks per axis.
"""
bytes_per_chunk = np.prod(chunk_shape) * item_size
if max_bytes < bytes_per_chunk:
return 1
num_axes = len(chunk_shape)
chunks_per_shard = 1
# First check for byte size, second check to make sure we don't go bigger than the array shape
while (bytes_per_chunk * ((chunks_per_shard + 1) ** num_axes)) <= max_bytes and all(
c * (chunks_per_shard + 1) <= a for c, a in zip(chunk_shape, array_shape, strict=True)
):
chunks_per_shard += 1
return chunks_per_shard


def _auto_partition(
*,
array_shape: tuple[int, ...],
Expand Down Expand Up @@ -237,12 +275,22 @@ def _auto_partition(
stacklevel=2,
)
_shards_out = ()
target_shard_size_bytes = zarr.config.get("array.target_shard_size_bytes", None)
num_chunks_per_shard_axis = (
_guess_num_chunks_per_axis_shard(
chunk_shape=_chunks_out,
item_size=item_size,
max_bytes=target_shard_size_bytes,
array_shape=array_shape,
)
if (has_auto_shard := (target_shard_size_bytes is not None))
else 2
)
for a_shape, c_shape in zip(array_shape, _chunks_out, strict=True):
# TODO: make a better heuristic than this.
# for each axis, if there are more than 8 chunks along that axis, then put
# 2 chunks in each shard for that axis.
if a_shape // c_shape > 8:
_shards_out += (c_shape * 2,)
# The previous heuristic was `a_shape // c_shape > 8` and now, with target_shard_size_bytes, we only check that the shard size is less than the array size.
can_shard_axis = a_shape // c_shape > 8 if not has_auto_shard else True
if can_shard_axis:
_shards_out += (c_shape * num_chunks_per_shard_axis,)
else:
_shards_out += (c_shape,)
elif isinstance(shard_shape, dict):
Expand Down
1 change: 1 addition & 0 deletions src/zarr/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def enable_gpu(self) -> ConfigSet:
"array": {
"order": "C",
"write_empty_chunks": False,
"target_shard_size_bytes": None,
},
"async": {"concurrency": 10, "timeout": None},
"threading": {"max_workers": None},
Expand Down
53 changes: 38 additions & 15 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,33 +966,56 @@ async def test_nbytes(


@pytest.mark.parametrize(
("array_shape", "chunk_shape"),
[((256,), (2,))],
("array_shape", "chunk_shape", "target_shard_size_bytes", "expected_shards"),
[
pytest.param(
(256, 256),
(32, 32),
129 * 129,
(128, 128),
id="2d_chunking_max_byes_does_not_evenly_divide",
),
pytest.param(
(256, 256), (32, 32), 64 * 64, (64, 64), id="2d_chunking_max_byes_evenly_divides"
),
pytest.param(
(256, 256),
(64, 32),
128 * 128,
(128, 64),
id="2d_non_square_chunking_max_byes_evenly_divides",
),
pytest.param((256,), (2,), 255, (254,), id="max_bytes_just_below_array_shape"),
pytest.param((256,), (2,), 256, (256,), id="max_bytes_equal_to_array_shape"),
pytest.param((256,), (2,), 16, (16,), id="max_bytes_normal_val"),
pytest.param((256,), (2,), 2, (2,), id="max_bytes_same_as_chunk"),
pytest.param((256,), (2,), 1, (2,), id="max_bytes_less_than_chunk"),
pytest.param((256,), (2,), None, (4,), id="use_default_auto_setting"),
pytest.param((4,), (2,), None, (2,), id="small_array_shape_does_not_shard"),
],
)
def test_auto_partition_auto_shards(
array_shape: tuple[int, ...], chunk_shape: tuple[int, ...]
array_shape: tuple[int, ...],
chunk_shape: tuple[int, ...],
target_shard_size_bytes: int | None,
expected_shards: tuple[int, ...],
) -> None:
"""
Test that automatically picking a shard size returns a tuple of 2 * the chunk shape for any axis
where there are 8 or more chunks.
"""
dtype = np.dtype("uint8")
expected_shards: tuple[int, ...] = ()
for cs, a_len in zip(chunk_shape, array_shape, strict=False):
if a_len // cs >= 8:
expected_shards += (2 * cs,)
else:
expected_shards += (cs,)
with pytest.warns(
ZarrUserWarning,
match="Automatic shard shape inference is experimental and may change without notice.",
):
auto_shards, _ = _auto_partition(
array_shape=array_shape,
chunk_shape=chunk_shape,
shard_shape="auto",
item_size=dtype.itemsize,
)
with zarr.config.set({"array.target_shard_size_bytes": target_shard_size_bytes}):
auto_shards, _ = _auto_partition(
array_shape=array_shape,
chunk_shape=chunk_shape,
shard_shape="auto",
item_size=dtype.itemsize,
)
assert auto_shards == expected_shards


Expand Down
1 change: 1 addition & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_config_defaults_set() -> None:
"array": {
"order": "C",
"write_empty_chunks": False,
"target_shard_size_bytes": None,
},
"async": {"concurrency": 10, "timeout": None},
"threading": {"max_workers": None},
Expand Down