Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,8 +1097,12 @@ def open_datatree(
Additional keyword arguments passed on to the engine open function.
For example:

- 'group': path to the group in the given file to open as the root group as
a str.
- 'group': path to the group in the given file to open as the root
group as a str. If the string contains glob metacharacters
(``*``, ``?``, ``[``), it is interpreted as a pattern and only
groups whose paths match are loaded (along with their ancestors).
For example, ``group="*/sweep_0"`` loads every ``sweep_0`` one
level deep while skipping sibling groups.
- 'lock': resource lock to use when reading data from disk. Only
relevant when using dask or another form of parallelism. By default,
appropriate locks are chosen to safely read and write files with the
Expand Down Expand Up @@ -1344,8 +1348,12 @@ def open_groups(
Additional keyword arguments passed on to the engine open function.
For example:

- 'group': path to the group in the given file to open as the root group as
a str.
- 'group': path to the group in the given file to open as the root
group as a str. If the string contains glob metacharacters
(``*``, ``?``, ``[``), it is interpreted as a pattern and only
groups whose paths match are loaded (along with their ancestors).
For example, ``group="*/sweep_0"`` loads every ``sweep_0`` one
level deep while skipping sibling groups.
- 'lock': resource lock to use when reading data from disk. Only
relevant when using dask or another form of parallelism. By default,
appropriate locks are chosen to safely read and write files with the
Expand Down
31 changes: 31 additions & 0 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,37 @@ def _iter_nc_groups(root, parent="/"):
yield from _iter_nc_groups(group, parent=gpath)


def _is_glob_pattern(pattern: str) -> bool:
return any(c in pattern for c in "*?[")


def _filter_group_paths(group_paths: Iterable[str], pattern: str) -> list[str]:
    """Return the subset of *group_paths* matching *pattern*, plus ancestors.

    Ancestors of every matching path are retained so the result can always
    be assembled into a connected tree rooted at ``"/"``. Input order is
    preserved.

    Parameters
    ----------
    group_paths : Iterable[str]
        Absolute group paths (e.g. ``"/a/sweep_0"``) to filter.
    pattern : str
        Glob pattern understood by ``NodePath.match``.

    Returns
    -------
    list[str]
        Matching paths and their ancestors, in the original input order.
    """
    from xarray.core.treenode import NodePath

    # Materialize up front: the annotation admits any Iterable, and we
    # iterate twice (matching pass below, order-preserving pass at the
    # end) — a one-shot generator would otherwise be exhausted after the
    # first pass and the function would return an empty list.
    paths = list(group_paths)

    matched: set[str] = {"/"}
    for path in paths:
        np_ = NodePath(path)
        if np_.match(pattern):
            matched.add(path)
            # Keep every ancestor so a matched group remains reachable
            # from the root when the tree is reconstructed.
            for parent in np_.parents:
                p = str(parent)
                if p:
                    matched.add(p)

    return [p for p in paths if p in matched]


def _resolve_group_and_filter(
    group: str | None,
    all_group_paths: list[str],
) -> tuple[str | None, list[str]]:
    """Interpret *group* as either a literal root-group path or a glob pattern.

    Returns an ``(effective_group, group_paths)`` pair: a glob pattern yields
    ``(None, <filtered paths>)`` because no single root group applies, while a
    literal path (or ``None``) passes through with the full path list intact.
    """
    if group is not None and _is_glob_pattern(group):
        # Glob pattern: open from the root and keep only matching subtrees.
        return None, _filter_group_paths(all_group_paths, group)
    # Literal group path or no group at all: nothing to filter.
    return group, all_group_paths


def find_root_and_group(ds):
"""Find the root and group name of a netCDF4/h5netcdf dataset."""
hierarchy = ()
Expand Down
22 changes: 15 additions & 7 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,11 @@ def open_groups_as_dict(
driver_kwds=None,
**kwargs,
) -> dict[str, Dataset]:
from xarray.backends.common import _iter_nc_groups
from xarray.backends.common import (
_is_glob_pattern,
_iter_nc_groups,
_resolve_group_and_filter,
)
from xarray.core.treenode import NodePath
from xarray.core.utils import close_on_error

Expand All @@ -644,10 +648,12 @@ def open_groups_as_dict(
emit_phony_dims_warning, phony_dims = _check_phony_dims(phony_dims)

filename_or_obj = _normalize_filename_or_obj(filename_or_obj)

effective_group = None if (group and _is_glob_pattern(group)) else group
store = H5NetCDFStore.open(
filename_or_obj,
format=format,
group=group,
group=effective_group,
lock=lock,
invalid_netcdf=invalid_netcdf,
phony_dims=phony_dims,
Expand All @@ -656,15 +662,17 @@ def open_groups_as_dict(
driver_kwds=driver_kwds,
)

# Check for a group and make it a parent if it exists
if group:
parent = NodePath("/") / NodePath(group)
if effective_group:
parent = NodePath("/") / NodePath(effective_group)
else:
parent = NodePath("/")

manager = store._manager
all_group_paths = list(_iter_nc_groups(store.ds, parent=parent))
_, filtered_paths = _resolve_group_and_filter(group, all_group_paths)

groups_dict = {}
for path_group in _iter_nc_groups(store.ds, parent=parent):
for path_group in filtered_paths:
group_store = H5NetCDFStore(manager, group=path_group, **kwargs)
store_entrypoint = StoreBackendEntrypoint()
with close_on_error(group_store):
Expand All @@ -679,7 +687,7 @@ def open_groups_as_dict(
decode_timedelta=decode_timedelta,
)

if group:
if effective_group:
group_name = str(NodePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
Expand Down
22 changes: 15 additions & 7 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,13 +859,19 @@ def open_groups_as_dict(
autoclose=False,
**kwargs,
) -> dict[str, Dataset]:
from xarray.backends.common import _iter_nc_groups
from xarray.backends.common import (
_is_glob_pattern,
_iter_nc_groups,
_resolve_group_and_filter,
)
from xarray.core.treenode import NodePath

filename_or_obj = _normalize_path(filename_or_obj)

effective_group = None if (group and _is_glob_pattern(group)) else group
store = NetCDF4DataStore.open(
filename_or_obj,
group=group,
group=effective_group,
format=format,
clobber=clobber,
diskless=diskless,
Expand All @@ -875,15 +881,17 @@ def open_groups_as_dict(
autoclose=autoclose,
)

# Check for a group and make it a parent if it exists
if group:
parent = NodePath("/") / NodePath(group)
if effective_group:
parent = NodePath("/") / NodePath(effective_group)
else:
parent = NodePath("/")

manager = store._manager
all_group_paths = list(_iter_nc_groups(store.ds, parent=parent))
_, filtered_paths = _resolve_group_and_filter(group, all_group_paths)

groups_dict = {}
for path_group in _iter_nc_groups(store.ds, parent=parent):
for path_group in filtered_paths:
group_store = NetCDF4DataStore(manager, group=path_group, **kwargs)
store_entrypoint = StoreBackendEntrypoint()
with close_on_error(group_store):
Expand All @@ -897,7 +905,7 @@ def open_groups_as_dict(
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
if group:
if effective_group:
group_name = str(NodePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
Expand Down
58 changes: 43 additions & 15 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,10 +1900,13 @@ def open_datatree(
zarr_format=None,
max_concurrency: int | None = None,
) -> DataTree:
from xarray.backends.common import _is_glob_pattern, _resolve_group_and_filter

filename_or_obj = _normalize_path(filename_or_obj)

if group:
parent = str(NodePath("/") / NodePath(group))
effective_group = None if (group and _is_glob_pattern(group)) else group
if effective_group:
parent = str(NodePath("/") / NodePath(effective_group))
else:
parent = str(NodePath("/"))

Expand Down Expand Up @@ -1964,8 +1967,11 @@ def open_datatree(
zarr_version=zarr_version,
zarr_format=zarr_format,
)
all_paths = list(stores.keys())
_, filtered_paths = _resolve_group_and_filter(group, all_paths)
groups_dict = {}
for path_group, store in stores.items():
for path_group in filtered_paths:
store = stores[path_group]
store_entrypoint = StoreBackendEntrypoint()
with close_on_error(store):
group_ds = store_entrypoint.open_dataset(
Expand All @@ -1978,7 +1984,7 @@ def open_datatree(
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
if group:
if effective_group:
group_name = str(NodePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
Expand Down Expand Up @@ -2045,6 +2051,16 @@ async def _open_datatree_from_stores_async(
if parent_path in group_children:
group_children[parent_path][child_name] = member

# Filter groups when glob pattern is used
from xarray.backends.common import _resolve_group_and_filter

effective_group, filtered_paths = _resolve_group_and_filter(
group, list(group_async.keys())
)
filtered_set = set(filtered_paths)
group_async = {k: v for k, v in group_async.items() if k in filtered_set}
group_children = {k: v for k, v in group_children.items() if k in filtered_set}

# Phase 2: Open each group — wrap async objects, run CPU decode in threads.
async def open_one(path_group: str) -> tuple[str, Dataset]:
async_grp = group_async[path_group]
Expand Down Expand Up @@ -2091,7 +2107,7 @@ def _cpu_open():
)

ds = await loop.run_in_executor(executor, _cpu_open)
if group:
if effective_group:
group_name = str(NodePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
Expand Down Expand Up @@ -2132,11 +2148,13 @@ def open_groups_as_dict(
zarr_version=None,
zarr_format=None,
) -> dict[str, Dataset]:
from xarray.backends.common import _is_glob_pattern, _resolve_group_and_filter

filename_or_obj = _normalize_path(filename_or_obj)

# Check for a group and make it a parent if it exists
if group:
parent = str(NodePath("/") / NodePath(group))
effective_group = None if (group and _is_glob_pattern(group)) else group
if effective_group:
parent = str(NodePath("/") / NodePath(effective_group))
else:
parent = str(NodePath("/"))

Expand All @@ -2153,8 +2171,11 @@ def open_groups_as_dict(
zarr_format=zarr_format,
)

_, filtered_paths = _resolve_group_and_filter(group, list(stores.keys()))

groups_dict = {}
for path_group, store in stores.items():
for path_group in filtered_paths:
store = stores[path_group]
store_entrypoint = StoreBackendEntrypoint()

with close_on_error(store):
Expand All @@ -2168,7 +2189,7 @@ def open_groups_as_dict(
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
if group:
if effective_group:
group_name = str(NodePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
Expand Down Expand Up @@ -2200,11 +2221,13 @@ async def open_groups_as_dict_async(
This mirrors open_groups_as_dict but parallelizes per-group Dataset opening,
which can significantly reduce latency on high-RTT object stores.
"""
from xarray.backends.common import _is_glob_pattern, _resolve_group_and_filter

filename_or_obj = _normalize_path(filename_or_obj)

# Determine parent group path context
if group:
parent = str(NodePath("/") / NodePath(group))
effective_group = None if (group and _is_glob_pattern(group)) else group
if effective_group:
parent = str(NodePath("/") / NodePath(effective_group))
else:
parent = str(NodePath("/"))

Expand All @@ -2221,6 +2244,9 @@ async def open_groups_as_dict_async(
zarr_format=zarr_format,
)

_, filtered_paths = _resolve_group_and_filter(group, list(stores.keys()))
filtered_set = set(filtered_paths)

loop = asyncio.get_running_loop()
max_workers = min(len(stores), 10) if stores else 1
executor = ThreadPoolExecutor(
Expand All @@ -2244,15 +2270,17 @@ def _load_sync():
)

ds = await loop.run_in_executor(executor, _load_sync)
if group:
if effective_group:
group_name = str(NodePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
return group_name, ds

try:
tasks = [
open_one(path_group, store) for path_group, store in stores.items()
open_one(path_group, store)
for path_group, store in stores.items()
if path_group in filtered_set
]
results = await asyncio.gather(*tasks)
finally:
Expand Down
Loading