diff --git a/src/modelarrayio/cli/cifti_to_h5.py b/src/modelarrayio/cli/cifti_to_h5.py index 6fec34b..d316e08 100644 --- a/src/modelarrayio/cli/cifti_to_h5.py +++ b/src/modelarrayio/cli/cifti_to_h5.py @@ -13,14 +13,13 @@ from tqdm import tqdm from modelarrayio.cli import utils as cli_utils -from modelarrayio.cli.parser_utils import add_scalar_columns_arg, add_to_modelarray_args +from modelarrayio.cli.parser_utils import add_to_modelarray_args from modelarrayio.utils.cifti import ( - _build_scalar_sources, - _cohort_to_long_dataframe, - _load_cohort_cifti, brain_names_to_dataframe, extract_cifti_scalar_data, + load_cohort_cifti, ) +from modelarrayio.utils.misc import build_scalar_sources, cohort_to_long_dataframe logger = logging.getLogger(__name__) @@ -47,7 +46,7 @@ def cifti_to_h5( Path to a csv with demographic info and paths to data backend : :obj:`str` Backend to use for storage (``'hdf5'`` or ``'tiledb'``) - output : :obj:`str` + output : :obj:`pathlib.Path` Output path. For the hdf5 backend, path to an .h5 file; for the tiledb backend, path to a .tdb directory. storage_dtype : :obj:`str` @@ -77,19 +76,18 @@ def cifti_to_h5( 0 if successful, 1 if failed. """ cohort_df = pd.read_csv(cohort_file) - cohort_long = _cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns) - output_path = Path(output) + cohort_long = cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns) if cohort_long.empty: raise ValueError('Cohort file does not contain any scalar entries after normalization.') - scalar_sources = _build_scalar_sources(cohort_long) + scalar_sources = build_scalar_sources(cohort_long) if not scalar_sources: raise ValueError('Unable to derive scalar sources from cohort file.') if backend == 'hdf5': - scalars, last_brain_names = _load_cohort_cifti(cohort_long, s3_workers) + scalars, last_brain_names = load_cohort_cifti(cohort_long, s3_workers) greyordinate_table, structure_names = brain_names_to_dataframe(last_brain_names) - output_path = cli_utils.prepare_output_parent(output_path) - with h5py.File(output_path, 'w') as h5_file: + output = cli_utils.prepare_output_parent(output) + with h5py.File(output, 'w') as h5_file: cli_utils.write_table_dataset( h5_file, 'greyordinates', @@ -107,9 +105,9 @@ def cifti_to_h5( chunk_voxels=chunk_voxels, target_chunk_mb=target_chunk_mb, ) - return int(not output_path.exists()) + return int(not output.exists()) - output_path.mkdir(parents=True, exist_ok=True) + output.mkdir(parents=True, exist_ok=True) if not scalar_sources: return 0 @@ -127,7 +125,7 @@ def _process_scalar_job(scalar_name, source_files): if rows: cli_utils.write_tiledb_scalar_matrices( - output_path, + output, {scalar_name: rows}, {scalar_name: source_files}, storage_dtype=storage_dtype, @@ -178,5 +176,4 @@ def _parse_cifti_to_h5(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) add_to_modelarray_args(parser, default_output='greyordinatearray.h5') - add_scalar_columns_arg(parser) return parser diff --git a/src/modelarrayio/cli/h5_to_mif.py b/src/modelarrayio/cli/h5_to_mif.py index bb6bfed..99e0d3b 100644 --- a/src/modelarrayio/cli/h5_to_mif.py +++ b/src/modelarrayio/cli/h5_to_mif.py @@ -14,7 +14,7 @@ from modelarrayio.cli import utils as cli_utils from modelarrayio.cli.parser_utils import _is_file, add_from_modelarray_args, add_log_level_arg -from modelarrayio.utils.fixels import mif_to_nifti2, nifti2_to_mif +from modelarrayio.utils.mif import mif_to_nifti2, nifti2_to_mif logger = logging.getLogger(__name__) diff --git a/src/modelarrayio/cli/mif_to_h5.py b/src/modelarrayio/cli/mif_to_h5.py index 38ed1cc..326cb87 100644 --- a/src/modelarrayio/cli/mif_to_h5.py +++ b/src/modelarrayio/cli/mif_to_h5.py @@ -4,7 +4,8 @@ import argparse import logging -from collections import defaultdict +import os +from concurrent.futures import ThreadPoolExecutor, as_completed from functools import partial from pathlib import Path @@ -14,7 +15,8 @@ from modelarrayio.cli import utils as cli_utils from modelarrayio.cli.parser_utils import _is_file, add_to_modelarray_args -from modelarrayio.utils.fixels import gather_fixels, mif_to_nifti2 +from modelarrayio.utils.mif import gather_fixels, load_cohort_mif +from modelarrayio.utils.misc import cohort_to_long_dataframe logger = logging.getLogger(__name__) @@ -33,6 +35,7 @@ def mif_to_h5( target_chunk_mb=2.0, workers=None, s3_workers=1, + scalar_columns=None, ): """Load all fixeldb data and write to an HDF5 or TileDB file. @@ -75,25 +78,20 @@ def mif_to_h5( """ # gather fixel data fixel_table, voxel_table = gather_fixels(index_file, directions_file) - output_path = Path(output) - # gather cohort data cohort_df = pd.read_csv(cohort_file) + cohort_long = cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns) + if cohort_long.empty: + raise ValueError('Cohort file does not contain any scalar entries after normalization.') - # upload each cohort's data - scalars = defaultdict(list) - sources_lists = defaultdict(list) logger.info('Extracting .mif data...') - for row in tqdm(cohort_df.itertuples(index=False), total=cohort_df.shape[0]): - scalar_file = row.source_file - _scalar_img, scalar_data = mif_to_nifti2(scalar_file) - scalars[row.scalar_name].append(scalar_data) - sources_lists[row.scalar_name].append(row.source_file) + scalars, sources_lists = load_cohort_mif(cohort_long, s3_workers) + if not sources_lists: + raise ValueError('Unable to derive scalar sources from cohort file.') - # Write the output if backend == 'hdf5': - output_path = cli_utils.prepare_output_parent(output_path) - with h5py.File(output_path, 'w') as h5_file: + output = cli_utils.prepare_output_parent(output) + with h5py.File(output, 'w') as h5_file: cli_utils.write_table_dataset(h5_file, 'fixels', fixel_table) cli_utils.write_table_dataset(h5_file, 'voxels', voxel_table) cli_utils.write_hdf5_scalar_matrices( @@ -107,19 +105,42 @@ def mif_to_h5( chunk_voxels=chunk_voxels, target_chunk_mb=target_chunk_mb, ) - return int(not output_path.exists()) - - cli_utils.write_tiledb_scalar_matrices( - output_path, - scalars, - sources_lists, - storage_dtype=storage_dtype, - compression=compression, - compression_level=compression_level, - shuffle=shuffle, - chunk_voxels=chunk_voxels, - target_chunk_mb=target_chunk_mb, - ) + return int(not output.exists()) + + output.mkdir(parents=True, exist_ok=True) + + scalar_names = list(sources_lists.keys()) + worker_count = workers if isinstance(workers, int) and workers > 0 else None + if worker_count is None: + cpu_count = os.cpu_count() or 1 + worker_count = min(len(scalar_names), max(1, cpu_count)) + else: + worker_count = min(len(scalar_names), worker_count) + + def _write_scalar_job(scalar_name): + cli_utils.write_tiledb_scalar_matrices( + output, + {scalar_name: scalars[scalar_name]}, + {scalar_name: sources_lists[scalar_name]}, + storage_dtype=storage_dtype, + compression=compression, + compression_level=compression_level, + shuffle=shuffle, + chunk_voxels=chunk_voxels, + target_chunk_mb=target_chunk_mb, + ) + + if worker_count <= 1: + for scalar_name in scalar_names: + _write_scalar_job(scalar_name) + else: + with ThreadPoolExecutor(max_workers=worker_count) as executor: + futures = { + executor.submit(_write_scalar_job, scalar_name): scalar_name + for scalar_name in scalar_names + } + for future in tqdm(as_completed(futures), total=len(futures), desc='TileDB scalars'): + future.result() return 0 diff --git a/src/modelarrayio/cli/nifti_to_h5.py b/src/modelarrayio/cli/nifti_to_h5.py index 28e2422..34a12f5 100644 --- a/src/modelarrayio/cli/nifti_to_h5.py +++ b/src/modelarrayio/cli/nifti_to_h5.py @@ -4,6 +4,8 @@ import argparse import logging +import os +from concurrent.futures import ThreadPoolExecutor, as_completed from functools import partial from pathlib import Path @@ -11,10 +13,12 @@ import nibabel as nb import numpy as np import pandas as pd +from tqdm import tqdm from modelarrayio.cli import utils as cli_utils from modelarrayio.cli.parser_utils import _is_file, add_to_modelarray_args -from modelarrayio.utils.voxels import _load_cohort_voxels +from modelarrayio.utils.misc import cohort_to_long_dataframe +from modelarrayio.utils.nifti import load_cohort_voxels logger = logging.getLogger(__name__) @@ -32,6 +36,7 @@ def nifti_to_h5( target_chunk_mb=2.0, workers=None, s3_workers=1, + scalar_columns=None, ): """Load all volume data and write to an HDF5 or TileDB file. @@ -43,7 +48,7 @@ def nifti_to_h5( Path to a CSV with demographic info and paths to data. backend : :obj:`str` Storage backend (``'hdf5'`` or ``'tiledb'``). - output : :obj:`str` + output : :obj:`pathlib.Path` Output path. For the hdf5 backend, path to an .h5 file; for the tiledb backend, path to a .tdb directory. storage_dtype : :obj:`str` @@ -65,13 +70,14 @@ def nifti_to_h5( s3_workers : :obj:`int` Number of parallel workers for S3 downloads. Default 1. """ - cohort_df = pd.read_csv(cohort_file) - output_path = Path(output) - group_mask_img = nb.load(group_mask_file) group_mask_matrix = group_mask_img.get_fdata() > 0 voxel_coords = np.column_stack(np.nonzero(group_mask_matrix)) + cohort_df = pd.read_csv(cohort_file) + cohort_long = cohort_to_long_dataframe(cohort_df, scalar_columns=scalar_columns) + if cohort_long.empty: + raise ValueError('Cohort file does not contain any scalar entries after normalization.') voxel_table = pd.DataFrame( { 'voxel_id': np.arange(voxel_coords.shape[0]), @@ -82,11 +88,13 @@ def nifti_to_h5( ) logger.info('Extracting NIfTI data...') - scalars, sources_lists = _load_cohort_voxels(cohort_df, group_mask_matrix, s3_workers) + scalars, sources_lists = load_cohort_voxels(cohort_long, group_mask_matrix, s3_workers) + if not sources_lists: + raise ValueError('Unable to derive scalar sources from cohort file.') if backend == 'hdf5': - output_path = cli_utils.prepare_output_parent(output_path) - with h5py.File(output_path, 'w') as h5_file: + output = cli_utils.prepare_output_parent(output) + with h5py.File(output, 'w') as h5_file: cli_utils.write_table_dataset(h5_file, 'voxels', voxel_table) cli_utils.write_hdf5_scalar_matrices( h5_file, @@ -99,19 +107,42 @@ def nifti_to_h5( chunk_voxels=chunk_voxels, target_chunk_mb=target_chunk_mb, ) - return int(not output_path.exists()) - - cli_utils.write_tiledb_scalar_matrices( - output_path, - scalars, - sources_lists, - storage_dtype=storage_dtype, - compression=compression, - compression_level=compression_level, - shuffle=shuffle, - chunk_voxels=chunk_voxels, - target_chunk_mb=target_chunk_mb, - ) + return int(not output.exists()) + + output.mkdir(parents=True, exist_ok=True) + + scalar_names = list(sources_lists.keys()) + worker_count = workers if isinstance(workers, int) and workers > 0 else None + if worker_count is None: + cpu_count = os.cpu_count() or 1 + worker_count = min(len(scalar_names), max(1, cpu_count)) + else: + worker_count = min(len(scalar_names), worker_count) + + def _write_scalar_job(scalar_name): + cli_utils.write_tiledb_scalar_matrices( + output, + {scalar_name: scalars[scalar_name]}, + {scalar_name: sources_lists[scalar_name]}, + storage_dtype=storage_dtype, + compression=compression, + compression_level=compression_level, + shuffle=shuffle, + chunk_voxels=chunk_voxels, + target_chunk_mb=target_chunk_mb, + ) + + if worker_count <= 1: + for scalar_name in scalar_names: + _write_scalar_job(scalar_name) + else: + with ThreadPoolExecutor(max_workers=worker_count) as executor: + futures = { + executor.submit(_write_scalar_job, scalar_name): scalar_name + for scalar_name in scalar_names + } + for future in tqdm(as_completed(futures), total=len(futures), desc='TileDB scalars'): + future.result() return 0 diff --git a/src/modelarrayio/cli/parser_utils.py b/src/modelarrayio/cli/parser_utils.py index 55c2c4a..a083699 100644 --- a/src/modelarrayio/cli/parser_utils.py +++ b/src/modelarrayio/cli/parser_utils.py @@ -20,6 +20,16 @@ def add_to_modelarray_args(parser, default_output='output.h5'): 'for the tiledb backend, path to a .tdb directory.' ), default=default_output, + type=Path, + ) + parser.add_argument( + '--scalar-columns', + '--scalar_columns', + nargs='+', + help=( + 'Column names containing scalar file paths when the cohort table is in wide format. ' + 'If omitted, the cohort file must include "scalar_name" and "source_file" columns.' + ), ) parser.add_argument( '--backend', @@ -110,19 +120,6 @@ def add_to_modelarray_args(parser, default_output='output.h5'): return parser -def add_scalar_columns_arg(parser): - parser.add_argument( - '--scalar-columns', - '--scalar_columns', - nargs='+', - help=( - 'Column names containing scalar file paths when the cohort table is in wide format. ' - "If omitted, the cohort file must include 'scalar_name' and 'source_file' columns." - ), - ) - return parser - - def add_log_level_arg(parser): parser.add_argument( '--log-level', diff --git a/src/modelarrayio/storage/h5_storage.py b/src/modelarrayio/storage/h5_storage.py index fedf62a..7c08e1c 100644 --- a/src/modelarrayio/storage/h5_storage.py +++ b/src/modelarrayio/storage/h5_storage.py @@ -16,10 +16,38 @@ def resolve_dtype(storage_dtype): + """Resolve a storage dtype to a supported NumPy floating type. + + Parameters + ---------- + storage_dtype : :obj:`str` + Storage dtype. + + Returns + ------- + :obj:`numpy.dtype` + Supported NumPy floating type. + """ return storage_utils.resolve_dtype(storage_dtype) def resolve_compression(compression, compression_level, shuffle): + """Resolve a compression method to a supported compression method. + + Parameters + ---------- + compression : :obj:`str` + Compression method. + compression_level : :obj:`int` + Compression level. + shuffle : :obj:`bool` + Whether to shuffle the data. + + Returns + ------- + :obj:`tuple` + Compression method, compression level, and whether to shuffle the data. + """ comp = ( None if compression is None or str(compression).lower() == 'none' @@ -39,6 +67,26 @@ def resolve_compression(compression, compression_level, shuffle): def compute_chunk_shape_full_subjects( num_subjects, num_items, item_chunk, target_chunk_mb, storage_np_dtype ): + """Compute a chunk shape for a full subject. + + Parameters + ---------- + num_subjects : :obj:`int` + Number of subjects. + num_items : :obj:`int` + Number of items. + item_chunk : :obj:`int` + Item chunk. + target_chunk_mb : :obj:`float` + Target chunk size in MB. + storage_np_dtype : :obj:`numpy.dtype` + Storage numpy dtype. + + Returns + ------- + :obj:`tuple` + Chunk shape. + """ chunk = storage_utils.compute_full_subject_chunk_shape( num_subjects=num_subjects, num_items=num_items, @@ -69,6 +117,36 @@ def create_scalar_matrix_dataset( chunk_voxels=0, target_chunk_mb=2.0, ): + """Create a scalar matrix dataset in an HDF5 file. + + Parameters + ---------- + h5file : :obj:`h5py.File` + HDF5 file. + dataset_path : :obj:`str` + Dataset path. + stacked_values : :obj:`numpy.ndarray` + Stacked values. + sources_list : :obj:`list` + Sources list. + storage_dtype : :obj:`str` + Storage dtype. + compression : :obj:`str` + Compression method. + compression_level : :obj:`int` + Compression level. + shuffle : :obj:`bool` + Whether to shuffle the data. + chunk_voxels : :obj:`int` + Chunk voxels. + target_chunk_mb : :obj:`float` + Target chunk size in MB. + + Returns + ------- + :obj:`h5py.Dataset` + Scalar matrix dataset. + """ storage_np_dtype = resolve_dtype(storage_dtype) comp, comp_opts, use_shuffle = resolve_compression(compression, compression_level, shuffle) @@ -118,6 +196,38 @@ def create_empty_scalar_matrix_dataset( target_chunk_mb=2.0, sources_list=None | pd.Series | list, ): + """Create an empty scalar matrix dataset in an HDF5 file. + + Parameters + ---------- + h5file : :obj:`h5py.File` + HDF5 file. + dataset_path : :obj:`str` + Dataset path. + num_subjects : :obj:`int` + Number of subjects. + num_items : :obj:`int` + Number of items. + storage_dtype : :obj:`str` + Storage dtype. + compression : :obj:`str` + Compression method. + compression_level : :obj:`int` + Compression level. + shuffle : :obj:`bool` + Whether to shuffle the data. + chunk_voxels : :obj:`int` + Chunk voxels. + target_chunk_mb : :obj:`float` + Target chunk size in MB. + sources_list : :obj:`list` + Sources list. + + Returns + ------- + :obj:`h5py.Dataset` + Empty scalar matrix dataset. + """ storage_np_dtype = resolve_dtype(storage_dtype) comp, comp_opts, use_shuffle = resolve_compression(compression, compression_level, shuffle) @@ -150,6 +260,17 @@ def create_empty_scalar_matrix_dataset( def write_column_names(h5_file: h5py.File, scalar: str, sources: pd.Series | list): + """Write column names to an HDF5 file. + + Parameters + ---------- + h5_file : :obj:`h5py.File` + HDF5 file. + scalar : :obj:`str` + Scalar name. + sources : :obj:`list` + Sources list. + """ values = np.array(storage_utils.normalize_column_names(sources), dtype=object) grp = h5_file.require_group(f'scalars/{scalar}') diff --git a/src/modelarrayio/storage/tiledb_storage.py b/src/modelarrayio/storage/tiledb_storage.py index df39374..33be392 100644 --- a/src/modelarrayio/storage/tiledb_storage.py +++ b/src/modelarrayio/storage/tiledb_storage.py @@ -16,6 +16,18 @@ def resolve_dtype(storage_dtype): + """Resolve a storage dtype to a supported NumPy floating type. + + Parameters + ---------- + storage_dtype : :obj:`str` + Storage dtype. + + Returns + ------- + :obj:`numpy.dtype` + Supported NumPy floating type. + """ return storage_utils.resolve_dtype(storage_dtype) @@ -46,6 +58,26 @@ def _build_filter_list(compression: str | None, compression_level: int | None, s def compute_tile_shape_full_subjects( num_subjects, num_items, item_tile, target_tile_mb, storage_np_dtype ): + """Compute a tile shape for a full subject. + + Parameters + ---------- + num_subjects : :obj:`int` + Number of subjects. + num_items : :obj:`int` + Number of items. + item_tile : :obj:`int` + Item tile. + target_tile_mb : :obj:`float` + Target tile size in MB. + storage_np_dtype : :obj:`numpy.dtype` + Storage numpy dtype. + + Returns + ------- + :obj:`tuple` + Tile shape. + """ tile = storage_utils.compute_full_subject_chunk_shape( num_subjects=num_subjects, num_items=num_items, @@ -82,6 +114,36 @@ def create_scalar_matrix_array( tile_voxels=0, target_tile_mb=2.0, ): + """Create a scalar matrix array in a TileDB directory. + + Parameters + ---------- + base_uri : :obj:`str` + Base URI. + dataset_path : :obj:`str` + Dataset path. + stacked_values : :obj:`numpy.ndarray` + Stacked values. + sources_list : :obj:`list` + Sources list. + storage_dtype : :obj:`str` + Storage dtype. + compression : :obj:`str` + Compression method. + compression_level : :obj:`int` + Compression level. + shuffle : :obj:`bool` + Whether to shuffle the data. + tile_voxels : :obj:`int` + Tile voxels. + target_tile_mb : :obj:`float` + Target tile size in MB. + + Returns + ------- + :obj:`str` + URI of the created array. + """ storage_np_dtype = resolve_dtype(storage_dtype) if stacked_values.dtype != storage_np_dtype: stacked_values = stacked_values.astype(storage_np_dtype) @@ -144,6 +206,38 @@ def create_empty_scalar_matrix_array( target_tile_mb=2.0, sources_list: Sequence[str] | None = None, ): + """Create an empty scalar matrix array in a TileDB directory. + + Parameters + ---------- + base_uri : :obj:`str` + Base URI. + dataset_path : :obj:`str` + Dataset path. + num_subjects : :obj:`int` + Number of subjects. + num_items : :obj:`int` + Number of items. + storage_dtype : :obj:`str` + Storage dtype. + compression : :obj:`str` + Compression method. + compression_level : :obj:`int` + Compression level. + shuffle : :obj:`bool` + Whether to shuffle the data. + tile_voxels : :obj:`int` + Tile voxels. + target_tile_mb : :obj:`float` + Target tile size in MB. + sources_list : :obj:`list` + Sources list. + + Returns + ------- + :obj:`str` + URI of the created array. + """ storage_np_dtype = resolve_dtype(storage_dtype) tile_shape = compute_tile_shape_full_subjects( num_subjects, num_items, tile_voxels, target_tile_mb, storage_np_dtype @@ -185,8 +279,7 @@ def create_empty_scalar_matrix_array( def write_rows_in_column_stripes(uri: str, rows: Sequence[np.ndarray]): - """ - Fill a 2D TileDB dense array by buffering column-aligned stripes to minimize + """Fill a 2D TileDB dense array by buffering column-aligned stripes to minimize tile writes, using about one tile's worth of memory. Parameters @@ -226,9 +319,16 @@ def write_rows_in_column_stripes(uri: str, rows: Sequence[np.ndarray]): def write_column_names(base_uri: str, scalar: str, sources: Sequence[str]): - """ - Store column names as a 1D dense TileDB array for the given scalar. - This mirrors the HDF5 dataset approach and scales to large cohorts. + """Store column names as a 1D dense TileDB array for the given scalar. + + Parameters + ---------- + base_uri : :obj:`str` + Base URI. + scalar : :obj:`str` + Scalar name. + sources : :obj:`list` + Sources list. """ sources = storage_utils.normalize_column_names(sources) uri = os.path.join(base_uri, 'scalars', scalar, 'column_names') @@ -239,7 +339,7 @@ def write_column_names(base_uri: str, scalar: str, sources: Sequence[str]): name='idx', domain=(0, max(n - 1, 0)), tile=max(1, min(n, 1024)), dtype=np.int64 ) dom = tiledb.Domain(dim_idx) - attr_values = tiledb.Attr(name='values', dtype=np.unicode_) + attr_values = tiledb.Attr(name='values', dtype=np.str_) schema = tiledb.ArraySchema(domain=dom, attrs=[attr_values], sparse=False) if tiledb.object_type(uri): diff --git a/src/modelarrayio/utils/cifti.py b/src/modelarrayio/utils/cifti.py index 6dfdfe8..b3b9a66 100644 --- a/src/modelarrayio/utils/cifti.py +++ b/src/modelarrayio/utils/cifti.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections import OrderedDict, defaultdict +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path @@ -14,51 +14,6 @@ from modelarrayio.utils.s3_utils import load_nibabel -def _cohort_to_long_dataframe(cohort_df, scalar_columns=None): - scalar_columns = [col for col in (scalar_columns or []) if col] - if scalar_columns: - missing = [col for col in scalar_columns if col not in cohort_df.columns] - if missing: - raise ValueError(f'Wide-format cohort is missing scalar columns: {missing}') - records = [] - selected_columns = cohort_df[scalar_columns] - for row_values in selected_columns.itertuples(index=False, name=None): - for scalar_col, source_val in zip(scalar_columns, row_values, strict=True): - if pd.isna(source_val) or source_val is None: - continue - source_str = str(source_val).strip() - if not source_str: - continue - records.append({'scalar_name': scalar_col, 'source_file': source_str}) - return pd.DataFrame.from_records(records, columns=['scalar_name', 'source_file']) - - required = {'scalar_name', 'source_file'} - missing = required - set(cohort_df.columns) - if missing: - raise ValueError( - f'Cohort file must contain columns {sorted(required)} when ' - '--scalar-columns is not used.' - ) - - long_df = cohort_df[list(required)].copy() - long_df = long_df.dropna(subset=['scalar_name', 'source_file']) - long_df['scalar_name'] = long_df['scalar_name'].astype(str).str.strip() - long_df['source_file'] = long_df['source_file'].astype(str).str.strip() - long_df = long_df[(long_df['scalar_name'] != '') & (long_df['source_file'] != '')] - return long_df.reset_index(drop=True) - - -def _build_scalar_sources(long_df): - scalar_sources = OrderedDict() - for row in long_df.itertuples(index=False): - scalar = str(row.scalar_name) - source = str(row.source_file) - if not scalar or not source: - continue - scalar_sources.setdefault(scalar, []).append(source) - return scalar_sources - - def extract_cifti_scalar_data(cifti_file, reference_brain_names=None): """Load a scalar cifti file and get its data and mapping @@ -137,7 +92,7 @@ def brain_names_to_dataframe(brain_names): return greyordinate_df, structure_name_strings -def _load_cohort_cifti(cohort_long, s3_workers): +def load_cohort_cifti(cohort_long, s3_workers): """Load all CIFTI scalar rows from the cohort, optionally in parallel. The first file is always loaded serially to obtain the reference brain diff --git a/src/modelarrayio/utils/fixels.py b/src/modelarrayio/utils/mif.py similarity index 67% rename from src/modelarrayio/utils/fixels.py rename to src/modelarrayio/utils/mif.py index 38e650e..aa9d913 100644 --- a/src/modelarrayio/utils/fixels.py +++ b/src/modelarrayio/utils/mif.py @@ -1,16 +1,26 @@ -"""Utility functions for fixel-wise data.""" +"""Utility functions for MIF data.""" import shutil import subprocess import tempfile +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import nibabel as nb import numpy as np import pandas as pd +from tqdm import tqdm def find_mrconvert(): + """Find the mrconvert executable on the system. + + Returns + ------- + :obj:`str` + Path to the mrconvert executable. + """ return shutil.which('mrconvert') @@ -91,6 +101,69 @@ def mif_to_nifti2(mif_file): return nifti2_img, data +def load_cohort_mif(cohort_long, s3_workers): + """Load all MIF scalar rows from the cohort, optionally in parallel. + + When s3_workers > 1, a ThreadPoolExecutor is used to run mrconvert + calls concurrently (subprocess calls release the GIL). Results arrive + via as_completed and are indexed by (scalar_name, subj_idx) so the + final ordered lists are reconstructed correctly regardless of completion + order. + + Parameters + ---------- + cohort_long : :obj:`pandas.DataFrame` + Long-format cohort dataframe with columns 'scalar_name' and 'source_file'. + s3_workers : :obj:`int` + Number of parallel workers for loading. + + Returns + ------- + scalars : dict[str, list[np.ndarray]] + Per-scalar ordered list of 1-D subject arrays, ready for stripe-write. + sources_lists : dict[str, list[str]] + Per-scalar ordered list of source file paths (for HDF5 metadata). + """ + scalar_subj_counter = defaultdict(int) + jobs = [] + sources_lists = defaultdict(list) + + for row in cohort_long.itertuples(index=False): + sn = row.scalar_name + subj_idx = scalar_subj_counter[sn] + scalar_subj_counter[sn] += 1 + src = row.source_file + jobs.append((sn, subj_idx, src)) + sources_lists[sn].append(src) + + def _worker(job): + sn, subj_idx, src = job + _img, data = mif_to_nifti2(src) + return sn, subj_idx, data + + if s3_workers > 1: + results = defaultdict(dict) + with ThreadPoolExecutor(max_workers=s3_workers) as pool: + futures = {pool.submit(_worker, job): job for job in jobs} + for future in tqdm( + as_completed(futures), + total=len(futures), + desc='Loading MIF data', + ): + sn, subj_idx, data = future.result() + results[sn][subj_idx] = data + scalars = { + sn: [results[sn][i] for i in range(cnt)] for sn, cnt in scalar_subj_counter.items() + } + else: + scalars = defaultdict(list) + for job in tqdm(jobs, desc='Loading MIF data'): + sn, subj_idx, data = _worker(job) + scalars[sn].append(data) + + return scalars, sources_lists + + def gather_fixels(index_file, directions_file): """Load the index and directions files to get lookup tables. diff --git a/src/modelarrayio/utils/misc.py b/src/modelarrayio/utils/misc.py new file mode 100644 index 0000000..f7fcd61 --- /dev/null +++ b/src/modelarrayio/utils/misc.py @@ -0,0 +1,86 @@ +"""Miscellaneous utility functions.""" + +from __future__ import annotations + +from collections import OrderedDict + +import pandas as pd + + +def cohort_to_long_dataframe(cohort_df, scalar_columns=None): + """Convert a wide-format cohort dataframe to a long-format dataframe. + + Parameters + ---------- + cohort_df : :obj:`pandas.DataFrame` + Wide-format cohort dataframe + scalar_columns : :obj:`list` + List of scalar columns to use. If provided, these columns are treated as + file-path columns and melted into 'scalar_name'/'source_file' rows. All + remaining columns (e.g. 'source_mask_file') are broadcast to every output + row. If not provided, the dataframe is treated as already long-format. + + Returns + ------- + long_df : :obj:`pandas.DataFrame` + Long-format cohort dataframe with columns 'scalar_name', 'source_file', + and any non-scalar columns from the input. + """ + scalar_columns = [col for col in (scalar_columns or []) if col] + if scalar_columns: + missing = [col for col in scalar_columns if col not in cohort_df.columns] + if missing: + raise ValueError(f'Wide-format cohort is missing scalar columns: {missing}') + extra_columns = [col for col in cohort_df.columns if col not in scalar_columns] + records = [] + for _, row in cohort_df.iterrows(): + extra = {col: row[col] for col in extra_columns} + for scalar_col in scalar_columns: + source_val = row[scalar_col] + if pd.isna(source_val) or source_val is None: + continue + source_str = str(source_val).strip() + if not source_str: + continue + records.append({'scalar_name': scalar_col, 'source_file': source_str, **extra}) + output_columns = ['scalar_name', 'source_file'] + extra_columns + return pd.DataFrame.from_records(records, columns=output_columns) + + required = {'scalar_name', 'source_file'} + missing = required - set(cohort_df.columns) + if missing: + raise ValueError( + f'Cohort file must contain columns {sorted(required)} when ' + '--scalar-columns is not used.' + ) + + long_df = cohort_df.copy() + long_df = long_df.dropna(subset=['scalar_name', 'source_file']) + long_df['scalar_name'] = long_df['scalar_name'].astype(str).str.strip() + long_df['source_file'] = long_df['source_file'].astype(str).str.strip() + long_df = long_df[(long_df['scalar_name'] != '') & (long_df['source_file'] != '')] + return long_df.reset_index(drop=True) + + +def build_scalar_sources(long_df): + """Build a dictionary of scalar sources from a long dataframe. + + Parameters + ---------- + long_df : :obj:`pandas.DataFrame` + Long-format cohort dataframe with columns 'scalar_name' and 'source_file'. + + Returns + ------- + scalar_sources : :obj:`OrderedDict` + Dictionary of scalar sources. + Keys are scalar names, values are lists of source files. + """ + scalar_sources = OrderedDict() + for row in long_df.itertuples(index=False): + scalar = str(row.scalar_name) + source = str(row.source_file) + if not scalar or not source: + continue + scalar_sources.setdefault(scalar, []).append(source) + return scalar_sources diff --git a/src/modelarrayio/utils/voxels.py b/src/modelarrayio/utils/nifti.py similarity index 79% rename from src/modelarrayio/utils/voxels.py rename to src/modelarrayio/utils/nifti.py index 0912b84..9e18ff7 100644 --- a/src/modelarrayio/utils/voxels.py +++ b/src/modelarrayio/utils/nifti.py @@ -1,4 +1,4 @@ -"""Utility functions for voxel-wise data.""" +"""Utility functions for NIfTI data.""" from __future__ import annotations @@ -13,7 +13,7 @@ from modelarrayio.utils.s3_utils import load_nibabel -def _load_cohort_voxels(cohort_df, group_mask_matrix, s3_workers): +def load_cohort_voxels(cohort_long, group_mask_matrix, s3_workers): """Load all voxel rows from the cohort, optionally in parallel. When s3_workers > 1, a ThreadPoolExecutor is used. Threads share memory so @@ -21,6 +21,16 @@ def _load_cohort_voxels(cohort_df, group_mask_matrix, s3_workers): arrive via as_completed and are indexed by (scalar_name, subj_idx) so the final ordered lists are reconstructed correctly regardless of completion order. + Parameters + ---------- + cohort_long : :obj:`pandas.DataFrame` + Long-format cohort dataframe with columns 'scalar_name', 'source_file', + and 'source_mask_file'. + group_mask_matrix : :obj:`numpy.ndarray` + Boolean group mask array. + s3_workers : :obj:`int` + Number of parallel workers for loading. + Returns ------- scalars : dict[str, list[np.ndarray]] @@ -32,7 +42,7 @@ def _load_cohort_voxels(cohort_df, group_mask_matrix, s3_workers): jobs = [] sources_lists = defaultdict(list) - for row in cohort_df.itertuples(index=False): + for row in cohort_long.itertuples(index=False): sn = row.scalar_name subj_idx = scalar_subj_counter[sn] scalar_subj_counter[sn] += 1 @@ -72,6 +82,22 @@ def _worker(job): def flattened_image(scalar_image, scalar_mask, group_mask_matrix): + """Flatten a scalar image to a 1-D array. + + Parameters + ---------- + scalar_image : :obj:`nibabel.Nifti1Image` + Scalar image. + scalar_mask : :obj:`nibabel.Nifti1Image` + Scalar mask. + group_mask_matrix : :obj:`numpy.ndarray` + Group mask matrix. + + Returns + ------- + :obj:`numpy.ndarray` + Flattened scalar image. + """ scalar_mask_img = ( scalar_mask if hasattr(scalar_mask, 'get_fdata') else nb.load(Path(scalar_mask)) ) diff --git a/test/test_cifti_cohort.py b/test/test_cifti_cohort.py index 8117387..eee8e00 100644 --- a/test/test_cifti_cohort.py +++ b/test/test_cifti_cohort.py @@ -6,11 +6,8 @@ import pandas as pd import pytest -from modelarrayio.utils.cifti import ( - _build_scalar_sources, - _cohort_to_long_dataframe, - brain_names_to_dataframe, -) +from modelarrayio.utils.cifti import brain_names_to_dataframe +from modelarrayio.utils.misc import build_scalar_sources, cohort_to_long_dataframe def test_cohort_long_format_preserves_rows() -> None: @@ -21,9 +18,9 @@ def test_cohort_long_format_preserves_rows() -> None: 'extra_col': [1, 2], } ) - long_df = _cohort_to_long_dataframe(df) + long_df = cohort_to_long_dataframe(df) assert len(long_df) == 2 - assert set(long_df.columns) == {'scalar_name', 'source_file'} + assert set(long_df.columns) == {'extra_col', 'scalar_name', 'source_file'} assert long_df.iloc[0]['scalar_name'] == 'THICK' @@ -34,7 +31,7 @@ def test_cohort_long_format_strips_and_drops_empty() -> None: 'source_file': [' a.nii ', ' b.nii '], } ) - long_df = _cohort_to_long_dataframe(df) + long_df = cohort_to_long_dataframe(df) assert len(long_df) == 1 assert long_df.iloc[0]['scalar_name'] == 'THICK' @@ -47,7 +44,7 @@ def test_cohort_wide_format_expands_columns() -> None: 'FA': ['f1.nii', ''], } ) - long_df = _cohort_to_long_dataframe(df, scalar_columns=['THICK', 'FA']) + long_df = cohort_to_long_dataframe(df, scalar_columns=['THICK', 'FA']) # Row 2 has empty FA — skipped assert len(long_df) == 3 scalars = set(long_df['scalar_name']) @@ -57,13 +54,13 @@ def test_cohort_wide_format_expands_columns() -> None: def test_cohort_wide_format_missing_scalar_column_raises() -> None: df = pd.DataFrame({'THICK': ['a.nii']}) with pytest.raises(ValueError, match='missing scalar columns'): - _cohort_to_long_dataframe(df, scalar_columns=['THICK', 'MISSING']) + cohort_to_long_dataframe(df, scalar_columns=['THICK', 'MISSING']) def test_cohort_long_missing_required_raises() -> None: df = pd.DataFrame({'only_this': [1]}) with pytest.raises(ValueError, match='scalar_name'): - _cohort_to_long_dataframe(df) + cohort_to_long_dataframe(df) def test_build_scalar_sources_ordering() -> None: @@ -73,7 +70,7 @@ def test_build_scalar_sources_ordering() -> None: 'source_file': ['x1', 'x2', 'y1'], } ) - src = _build_scalar_sources(long_df) + src = build_scalar_sources(long_df) assert list(src.keys()) == ['A', 'B'] assert src['A'] == ['x1', 'x2'] assert src['B'] == ['y1'] diff --git a/test/test_fixels_utils.py b/test/test_fixels_utils.py index 179f5bf..91bb270 100644 --- a/test/test_fixels_utils.py +++ b/test/test_fixels_utils.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from modelarrayio.utils import fixels +from modelarrayio.utils import mif def _make_nifti2(shape=(2, 1, 1)) -> nb.Nifti2Image: @@ -15,14 +15,14 @@ def _make_nifti2(shape=(2, 1, 1)) -> nb.Nifti2Image: def test_nifti2_to_mif_raises_when_mrconvert_missing(tmp_path, monkeypatch) -> None: - monkeypatch.setattr(fixels, 'find_mrconvert', lambda: None) + monkeypatch.setattr(mif, 'find_mrconvert', lambda: None) with pytest.raises(FileNotFoundError, match='mrconvert'): - fixels.nifti2_to_mif(_make_nifti2(), tmp_path / 'out.mif') + mif.nifti2_to_mif(_make_nifti2(), tmp_path / 'out.mif') def test_mif_to_nifti2_raises_when_mrconvert_missing(monkeypatch) -> None: - monkeypatch.setattr(fixels, 'find_mrconvert', lambda: None) + monkeypatch.setattr(mif, 'find_mrconvert', lambda: None) with pytest.raises(FileNotFoundError, match='mrconvert'): - fixels.mif_to_nifti2('missing_input.mif') + mif.mif_to_nifti2('missing_input.mif') diff --git a/test/test_parser_utils.py b/test/test_parser_utils.py index 4aed56a..6fa13d4 100644 --- a/test/test_parser_utils.py +++ b/test/test_parser_utils.py @@ -3,6 +3,7 @@ from __future__ import annotations import argparse +from pathlib import Path from modelarrayio.cli import parser_utils @@ -66,7 +67,7 @@ def test_output_hdf5_default_name_override(tmp_path_factory) -> None: p = argparse.ArgumentParser() parser_utils.add_to_modelarray_args(p, default_output='custom.h5') args = p.parse_args(['--cohort-file', str(cohort_file)]) - assert args.output == 'custom.h5' + assert args.output == Path('custom.h5') def test_tiledb_args_group(tmp_path_factory) -> None: @@ -76,7 +77,7 @@ def test_tiledb_args_group(tmp_path_factory) -> None: p = argparse.ArgumentParser() parser_utils.add_to_modelarray_args(p, default_output='arrays.tdb') args = p.parse_args(['--cohort-file', str(cohort_file), '--backend', 'tiledb']) - assert args.output == 'arrays.tdb' + assert args.output == Path('arrays.tdb') assert args.backend == 'tiledb' assert args.workers == 0 assert args.s3_workers == 1 diff --git a/test/test_voxels_utils.py b/test/test_voxels_utils.py index 82d7c92..dc56625 100644 --- a/test/test_voxels_utils.py +++ b/test/test_voxels_utils.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from modelarrayio.utils.voxels import flattened_image +from modelarrayio.utils.nifti import flattened_image def _eye_affine():