diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index bf793d3b4..a93a5ffd6 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -99,3 +99,16 @@ def test_sel_method_forwarded(gridpath, datasetpath): nearest["time"].values, np.array(uxds["time"].values[2], dtype="datetime64[ns]"), ) + +def test_uxdataset_init_from_xarray_dataset(): + ds = xr.Dataset( + data_vars={"a": ("x", [1, 2])}, + coords={"x": [10, 20]}, + attrs={"source": "testing"}, + ) + + uxds = ux.UxDataset(ds) + + assert "a" in uxds.data_vars + assert "x" in uxds.coords + assert uxds.attrs["source"] == "testing" diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index 3dc1feffc..fe0be4cbe 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -90,6 +90,24 @@ def __init__( else: self._uxgrid = uxgrid + # As of xarray's 2026.4.0, `xr.Dataset(xr.Dataset)` is prohibited; + # hence this check, i.e. if we get `xr.Dataset` as input, use its `data_vars` + # as `dict` and handle `coords` and `attrs` properly as well + if args and isinstance(args[0], xr.Dataset): + ds = args[0] + # Replacee only args[0], `ds`, with `ds.data_vars` as `dict` + args = (dict(ds.data_vars),) + args[1:] + # coords not passed positionally + if len(args) < 2: + kwargs.setdefault( + "coords", dict(ds.coords) + ) # Set it as kwarg only if not explicitly provided + # attrs not passed positionally + if len(args) < 3: + kwargs.setdefault( + "attrs", ds.attrs + ) # Set it as kwarg only if not explicitly provided + super().__init__(*args, **kwargs) # declare plotting accessor @@ -627,9 +645,9 @@ def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset: """ if grid_format == "HEALPix": ds = self.rename_dims({"n_face": "cell"}) - return xr.Dataset(ds) + return xr.Dataset(ds.data_vars, coords=ds.coords, attrs=ds.attrs) - return xr.Dataset(self) + return xr.Dataset(self.data_vars, coords=self.coords, attrs=self.attrs) def get_dual(self): """Compute the dual mesh for a dataset, returns a new dataset object.