From a8c9ff1b2a3c825104e5bf19e74346bee8a6516a Mon Sep 17 00:00:00 2001 From: erogluorhan Date: Thu, 16 Apr 2026 12:56:19 -0600 Subject: [PATCH 1/4] Fix UxDataset constructor and to_xarray() --- uxarray/core/dataset.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index 3dc1feffc..3f9f3fc64 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -90,8 +90,19 @@ def __init__( else: self._uxgrid = uxgrid + # As of xarray's 2026.4.0, `xr.Dataset(xr.Dataset)` is prohibited; + # hence this check, i.e. If we get `xr.Dataset` as input, use its `data_vars` + # as `dict` and + if args and isinstance(args[0], xr.Dataset): + ds = args[0] + args = (dict(ds.data_vars),) + args[1:] + kwargs.setdefault("coords", dict(ds.coords)) + kwargs.setdefault("attrs", ds.attrs) + super().__init__(*args, **kwargs) + # super().__init__(*args, **kwargs) + # declare plotting accessor plot = UncachedAccessor(UxDatasetPlotAccessor) remap = UncachedAccessor(RemapAccessor) @@ -627,9 +638,9 @@ def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset: """ if grid_format == "HEALPix": ds = self.rename_dims({"n_face": "cell"}) - return xr.Dataset(ds) + return xr.Dataset(ds.data_vars, coords=ds.coords, attrs=ds.attrs) - return xr.Dataset(self) + return xr.Dataset(self.data_vars, coords=self.coords, attrs=self.attrs) def get_dual(self): """Compute the dual mesh for a dataset, returns a new dataset object. From 78d612317d32587d5b15c5659ac9b47ece223371 Mon Sep 17 00:00:00 2001 From: erogluorhan Date: Thu, 16 Apr 2026 13:01:10 -0600 Subject: [PATCH 2/4] Remove forgotten line --- uxarray/core/dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index 3f9f3fc64..aad61c983 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -91,18 +91,18 @@ def __init__( self._uxgrid = uxgrid # As of xarray's 2026.4.0, `xr.Dataset(xr.Dataset)` is prohibited; - # hence this check, i.e. If we get `xr.Dataset` as input, use its `data_vars` - # as `dict` and + # hence this check, i.e. if we get `xr.Dataset` as input, use its `data_vars` + # as `dict` and handle `coords` and `attrs` properly as well if args and isinstance(args[0], xr.Dataset): ds = args[0] + # Replacee only args[0], `ds`, with `ds.data_vars` as `dict` args = (dict(ds.data_vars),) + args[1:] + # Set `coords` and `attrs` only if they are not explicitly provided kwargs.setdefault("coords", dict(ds.coords)) kwargs.setdefault("attrs", ds.attrs) super().__init__(*args, **kwargs) - # super().__init__(*args, **kwargs) - # declare plotting accessor plot = UncachedAccessor(UxDatasetPlotAccessor) remap = UncachedAccessor(RemapAccessor) From 0e80f8abbfe75a649bb7e8980cf48dd8d4083f50 Mon Sep 17 00:00:00 2001 From: erogluorhan Date: Fri, 17 Apr 2026 11:51:02 -0600 Subject: [PATCH 3/4] Address @rajeeja's pointed case --- uxarray/core/dataset.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index aad61c983..fe0be4cbe 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -97,9 +97,16 @@ def __init__( ds = args[0] # Replacee only args[0], `ds`, with `ds.data_vars` as `dict` args = (dict(ds.data_vars),) + args[1:] - # Set `coords` and `attrs` only if they are not explicitly provided - kwargs.setdefault("coords", dict(ds.coords)) - kwargs.setdefault("attrs", ds.attrs) + # coords not passed positionally + if len(args) < 2: + kwargs.setdefault( + "coords", dict(ds.coords) + ) # Set it as kwarg only if not explicitly provided + # attrs not passed positionally + if len(args) < 3: + kwargs.setdefault( + "attrs", ds.attrs + ) # Set it as kwarg only if not explicitly provided super().__init__(*args, **kwargs) From ab09d5ae4076933258498954a631b29b9f795e97 Mon Sep 17 00:00:00 2001 From: erogluorhan Date: Fri, 17 Apr 2026 11:57:45 -0600 Subject: [PATCH 4/4] Add test case --- test/core/test_dataset.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index bf793d3b4..a93a5ffd6 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -99,3 +99,16 @@ def test_sel_method_forwarded(gridpath, datasetpath): nearest["time"].values, np.array(uxds["time"].values[2], dtype="datetime64[ns]"), ) + +def test_uxdataset_init_from_xarray_dataset(): + ds = xr.Dataset( + data_vars={"a": ("x", [1, 2])}, + coords={"x": [10, 20]}, + attrs={"source": "testing"}, + ) + + uxds = ux.UxDataset(ds) + + assert "a" in uxds.data_vars + assert "x" in uxds.coords + assert uxds.attrs["source"] == "testing"