diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a160d2cc84e..4757102a579 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -19,6 +19,8 @@ New Features By `Stephan Hoyer `_. - The ``h5netcdf`` engine has support for pseudo ``NETCDF4_CLASSIC`` files, meaning variables and attributes are cast to supported types. Note that the saved files won't be recognized as genuine ``NETCDF4_CLASSIC`` files until ``h5netcdf`` adds support with version 1.7.0. (:issue:`10676`, :pull:`10686`). By `David Huard `_. +- Support comparing :py:class:`DataTree` objects with :py:func:`testing.assert_allclose` (:pull:`10887`). + By `Justus Magin `_. Breaking Changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a5a958ddcbe..925d82ef757 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1506,13 +1506,18 @@ def __delitem__(self, key: Hashable) -> None: # https://github.com/python/mypy/issues/4266 __hash__ = None # type: ignore[assignment] - def _all_compat(self, other: Self, compat_str: str) -> bool: + def _all_compat( + self, other: Self, compat: str | Callable[[Variable, Variable], bool] + ) -> bool: """Helper function for equals and identical""" - # some stores (e.g., scipy) do not seem to preserve order, so don't - # require matching order for equality - def compat(x: Variable, y: Variable) -> bool: - return getattr(x, compat_str)(y) + if not callable(compat): + compat_str = compat + + # some stores (e.g., scipy) do not seem to preserve order, so don't + # require matching order for equality + def compat(x: Variable, y: Variable) -> bool: + return getattr(x, compat_str)(y) return self._coord_names == other._coord_names and utils.dict_equiv( self._variables, other._variables, compat=compat diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 474a72da739..39e6da6a83c 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -239,6 +239,11 @@ def compat_variable(a, b): b = getattr(b, "variable", b) return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data)) + def compat_node(a, b): + return a.ds._coord_names == b.ds._coord_names and utils.dict_equiv( + a.variables, b.variables, compat=compat_variable + ) + if isinstance(a, Variable): allclose = compat_variable(a, b) assert allclose, formatting.diff_array_repr(a, b, compat=equiv) @@ -255,6 +260,11 @@ def compat_variable(a, b): elif isinstance(a, Coordinates): allclose = utils.dict_equiv(a.variables, b.variables, compat=compat_variable) assert allclose, formatting.diff_coords_repr(a, b, compat=equiv) + elif isinstance(a, DataTree): + allclose = utils.dict_equiv( + dict(a.subtree_with_keys), dict(b.subtree_with_keys), compat=compat_node + ) + assert allclose, formatting.diff_datatree_repr(a, b, compat=equiv) else: raise TypeError(f"{type(a)} not supported by assertion comparison") diff --git a/xarray/tests/test_assertions.py b/xarray/tests/test_assertions.py index 222a01a6628..6e18e47cc81 100644 --- a/xarray/tests/test_assertions.py +++ b/xarray/tests/test_assertions.py @@ -62,6 +62,19 @@ def test_allclose_regression() -> None: xr.Coordinates({"x": [0, 3]}), id="Coordinates", ), + pytest.param( + xr.DataTree.from_dict( + { + "/b": xr.Dataset({"a": ("x", [1e-17, 2]), "b": ("y", [-2e-18, 2])}), + } + ), + xr.DataTree.from_dict( + { + "/b": xr.Dataset({"a": ("x", [0, 2]), "b": ("y", [0, 1])}), + } + ), + id="DataTree", + ), ), ) def test_assert_allclose(obj1, obj2) -> None: