diff --git a/README.rst b/README.rst index e6f01a1..d5beafc 100644 --- a/README.rst +++ b/README.rst @@ -52,16 +52,19 @@ Once ModelArrayIO is installed, these commands are available in your terminal: * ``.mif`` → ``.h5``: ``modelarrayio mif-to-h5`` * ``.h5`` → ``.mif``: ``modelarrayio h5-to-mif`` + * ``.h5`` scalar row → ``.mif``: ``modelarrayio h5-export-mif-file`` * **Voxel-wise** data (NIfTI): * NIfTI → ``.h5``: ``modelarrayio nifti-to-h5`` * ``.h5`` → NIfTI: ``modelarrayio h5-to-nifti`` + * ``.h5`` scalar row → NIfTI: ``modelarrayio h5-export-nifti-file`` * **Greyordinate-wise** data (CIFTI-2): * CIFTI-2 → ``.h5``: ``modelarrayio cifti-to-h5`` * ``.h5`` → CIFTI-2: ``modelarrayio h5-to-cifti`` + * ``.h5`` scalar row → CIFTI-2: ``modelarrayio h5-export-cifti-file`` Storage backends: HDF5 and TileDB diff --git a/docs/usage.rst b/docs/usage.rst index 9a8e942..40e827a 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -3,9 +3,9 @@ Usage ##### -********** +********* mif-to-h5 -********** +********* .. argparse:: :ref: modelarrayio.cli.mif_to_h5._parse_mif_to_h5 @@ -13,9 +13,9 @@ mif-to-h5 :func: _parse_mif_to_h5 -********** +*********** nifti-to-h5 -********** +*********** .. argparse:: :ref: modelarrayio.cli.nifti_to_h5._parse_nifti_to_h5 @@ -23,9 +23,9 @@ nifti-to-h5 :func: _parse_nifti_to_h5 -********** +*********** cifti-to-h5 -********** +*********** .. argparse:: :ref: modelarrayio.cli.cifti_to_h5._parse_cifti_to_h5 @@ -33,9 +33,9 @@ cifti-to-h5 :func: _parse_cifti_to_h5 -********** +********* h5-to-mif -********** +********* .. argparse:: :ref: modelarrayio.cli.h5_to_mif._parse_h5_to_mif @@ -60,3 +60,33 @@ h5-to-cifti :ref: modelarrayio.cli.h5_to_cifti._parse_h5_to_cifti :prog: modelarrayio h5-to-cifti :func: _parse_h5_to_cifti + + +****************** +h5-export-mif-file +****************** + +.. 
argparse:: + :ref: modelarrayio.cli.h5_export_mif_file._parse_h5_export_mif_file + :prog: modelarrayio h5-export-mif-file + :func: _parse_h5_export_mif_file + + +******************** +h5-export-nifti-file +******************** + +.. argparse:: + :ref: modelarrayio.cli.h5_export_nifti_file._parse_h5_export_nifti_file + :prog: modelarrayio h5-export-nifti-file + :func: _parse_h5_export_nifti_file + + +******************** +h5-export-cifti-file +******************** + +.. argparse:: + :ref: modelarrayio.cli.h5_export_cifti_file._parse_h5_export_cifti_file + :prog: modelarrayio h5-export-cifti-file + :func: _parse_h5_export_cifti_file diff --git a/modelarrayio/__about__.py b/modelarrayio/__about__.py new file mode 100644 index 0000000..f46b8e4 --- /dev/null +++ b/modelarrayio/__about__.py @@ -0,0 +1,34 @@ +# file generated by setuptools-scm +# don't change, don't track in version control + +__all__ = [ + "__version__", + "__version_tuple__", + "version", + "version_tuple", + "__commit_id__", + "commit_id", +] + +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple + from typing import Union + + VERSION_TUPLE = Tuple[Union[int, str], ...] 
+ COMMIT_ID = Union[str, None] +else: + VERSION_TUPLE = object + COMMIT_ID = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE +commit_id: COMMIT_ID +__commit_id__: COMMIT_ID + +__version__ = version = '0.1.dev115+g7f4c6030e' +__version_tuple__ = version_tuple = (0, 1, 'dev115', 'g7f4c6030e') + +__commit_id__ = commit_id = None diff --git a/src/modelarrayio/cli/cifti_to_h5.py b/src/modelarrayio/cli/cifti_to_h5.py index 6fec34b..4a54106 100644 --- a/src/modelarrayio/cli/cifti_to_h5.py +++ b/src/modelarrayio/cli/cifti_to_h5.py @@ -12,8 +12,13 @@ import pandas as pd from tqdm import tqdm +from modelarrayio.cli import diagnostics as cli_diagnostics from modelarrayio.cli import utils as cli_utils -from modelarrayio.cli.parser_utils import add_scalar_columns_arg, add_to_modelarray_args +from modelarrayio.cli.parser_utils import ( + add_diagnostics_args, + add_scalar_columns_arg, + add_to_modelarray_args, +) from modelarrayio.utils.cifti import ( _build_scalar_sources, _cohort_to_long_dataframe, @@ -21,6 +26,7 @@ brain_names_to_dataframe, extract_cifti_scalar_data, ) +from modelarrayio.utils.s3_utils import load_nibabel logger = logging.getLogger(__name__) @@ -38,6 +44,9 @@ def cifti_to_h5( workers=None, s3_workers=1, scalar_columns=None, + no_diagnostics=False, + diagnostics_dir=None, + diagnostic_maps=None, ): """Load all CIFTI data and write to an HDF5 or TileDB file. @@ -70,6 +79,12 @@ def cifti_to_h5( Number of workers for parallel S3 downloads scalar_columns : :obj:`list` List of scalar columns to use + no_diagnostics : :obj:`bool` + Disable diagnostic outputs in native format. + diagnostics_dir : :obj:`str` or :obj:`None` + Output directory for diagnostics. Defaults to ``_diagnostics``. + diagnostic_maps : :obj:`list` or :obj:`None` + Diagnostic maps to write. Supported: ``mean``, ``element_id``, ``n_non_nan``. 
Returns ------- @@ -84,10 +99,35 @@ def cifti_to_h5( scalar_sources = _build_scalar_sources(cohort_long) if not scalar_sources: raise ValueError('Unable to derive scalar sources from cohort file.') + maps_to_write = cli_utils.normalize_diagnostic_maps(diagnostic_maps) + + _first_scalar, first_sources = next(iter(scalar_sources.items())) + first_path = first_sources[0] + template_cifti = load_nibabel(first_path, cifti=True) + _first_data, reference_brain_names = extract_cifti_scalar_data(template_cifti) + + if not no_diagnostics: + output_diag_dir = ( + Path(diagnostics_dir) + if diagnostics_dir is not None + else cli_utils.default_diagnostics_dir(output_path) + ) + output_diag_dir.mkdir(parents=True, exist_ok=True) + cli_diagnostics.verify_cifti_element_mapping(template_cifti, reference_brain_names) if backend == 'hdf5': scalars, last_brain_names = _load_cohort_cifti(cohort_long, s3_workers) greyordinate_table, structure_names = brain_names_to_dataframe(last_brain_names) + if not no_diagnostics: + for scalar_name, rows in scalars.items(): + diagnostics = cli_diagnostics.summarize_rows(rows) + cli_diagnostics.write_cifti_diagnostics( + maps=maps_to_write, + scalar_name=scalar_name, + diagnostics=diagnostics, + template_cifti=template_cifti, + output_dir=output_diag_dir, + ) output_path = cli_utils.prepare_output_parent(output_path) with h5py.File(output_path, 'w') as h5_file: cli_utils.write_table_dataset( @@ -113,10 +153,6 @@ def cifti_to_h5( if not scalar_sources: return 0 - _first_scalar, first_sources = next(iter(scalar_sources.items())) - first_path = first_sources[0] - _, reference_brain_names = extract_cifti_scalar_data(first_path) - def _process_scalar_job(scalar_name, source_files): rows = [] for source_file in source_files: @@ -126,6 +162,15 @@ def _process_scalar_job(scalar_name, source_files): rows.append(cifti_data) if rows: + if not no_diagnostics: + diagnostics = cli_diagnostics.summarize_rows(rows) + cli_diagnostics.write_cifti_diagnostics( + 
maps=maps_to_write, + scalar_name=scalar_name, + diagnostics=diagnostics, + template_cifti=template_cifti, + output_dir=output_diag_dir, + ) cli_utils.write_tiledb_scalar_matrices( output_path, {scalar_name: rows}, @@ -179,4 +224,5 @@ def _parse_cifti_to_h5(): ) add_to_modelarray_args(parser, default_output='greyordinatearray.h5') add_scalar_columns_arg(parser) + add_diagnostics_args(parser) return parser diff --git a/src/modelarrayio/cli/diagnostics.py b/src/modelarrayio/cli/diagnostics.py new file mode 100644 index 0000000..98d5c99 --- /dev/null +++ b/src/modelarrayio/cli/diagnostics.py @@ -0,0 +1,116 @@ +"""Diagnostic image helpers for conversion commands.""" + +from __future__ import annotations + +from pathlib import Path + +import nibabel as nb +import numpy as np + +from modelarrayio.utils.cifti import extract_cifti_scalar_data +from modelarrayio.utils.fixels import nifti2_to_mif +from modelarrayio.utils.voxels import flattened_image + + +def summarize_rows(rows) -> dict[str, np.ndarray]: + """Compute common diagnostics from a sequence of 1-D subject arrays.""" + stacked = np.vstack(rows) + return { + 'mean': np.nanmean(stacked, axis=0).astype(np.float32), + 'n_non_nan': np.sum(~np.isnan(stacked), axis=0).astype(np.float32), + 'element_id': np.arange(stacked.shape[1], dtype=np.float32), + } + + +def verify_nifti_element_mapping(group_mask_img, group_mask_matrix): + """Verify NIfTI group-mask flattening order matches element indices.""" + expected = np.arange(int(group_mask_matrix.sum()), dtype=np.float32) + element_volume = np.zeros(group_mask_matrix.shape, dtype=np.float32) + element_volume[group_mask_matrix] = expected + element_img = nb.Nifti1Image( + element_volume, + affine=group_mask_img.affine, + header=group_mask_img.header, + ) + extracted = flattened_image(element_img, group_mask_img, group_mask_matrix) + if not np.array_equal(extracted.astype(np.int64), expected.astype(np.int64)): + raise ValueError('Element ID mapping check failed for NIfTI 
group-mask flattening.') + + +def write_nifti_diagnostics( + *, + maps: list[str], + scalar_name: str, + diagnostics: dict[str, np.ndarray], + group_mask_img, + group_mask_matrix, + output_dir: Path, +): + header = group_mask_img.header.copy() + header.set_data_dtype(np.float32) + for name in maps: + out_file = output_dir / f'{scalar_name}_{name}.nii.gz' + data = np.zeros(group_mask_matrix.shape, dtype=np.float32) + data[group_mask_matrix] = diagnostics[name] + nb.Nifti1Image(data, affine=group_mask_img.affine, header=header).to_filename(out_file) + + +def verify_cifti_element_mapping(template_cifti, reference_brain_names): + """Verify CIFTI extraction order matches element indices.""" + expected = np.arange(reference_brain_names.shape[0], dtype=np.float32) + test_img = nb.Cifti2Image( + expected.reshape(1, -1), + header=template_cifti.header, + nifti_header=template_cifti.nifti_header, + ) + recovered, _ = extract_cifti_scalar_data(test_img, reference_brain_names=reference_brain_names) + if not np.array_equal(recovered.astype(np.int64), expected.astype(np.int64)): + raise ValueError('Element ID mapping check failed for CIFTI greyordinate ordering.') + + +def write_cifti_diagnostics( + *, + maps: list[str], + scalar_name: str, + diagnostics: dict[str, np.ndarray], + template_cifti, + output_dir: Path, +): + for name in maps: + out_file = output_dir / f'{scalar_name}_{name}.dscalar.nii' + nb.Cifti2Image( + diagnostics[name].reshape(1, -1), + header=template_cifti.header, + nifti_header=template_cifti.nifti_header, + ).to_filename(out_file) + + +def verify_mif_element_mapping(template_nifti2, num_elements: int): + """Verify fixel vector reshape/squeeze mapping remains identity.""" + expected = np.arange(num_elements, dtype=np.float32) + test_img = nb.Nifti2Image( + expected.reshape(-1, 1, 1), + affine=template_nifti2.affine, + header=template_nifti2.header, + ) + recovered = test_img.get_fdata(dtype=np.float32).squeeze() + if not 
np.array_equal(recovered.astype(np.int64), expected.astype(np.int64)): + raise ValueError('Element ID mapping check failed for MIF fixel vector ordering.') + + +def write_mif_diagnostics( + *, + maps: list[str], + scalar_name: str, + diagnostics: dict[str, np.ndarray], + template_nifti2, + output_dir: Path, +): + for name in maps: + out_file = output_dir / f'{scalar_name}_{name}.mif' + temp_nifti2 = nb.Nifti2Image( + diagnostics[name].reshape(-1, 1, 1), + affine=template_nifti2.affine, + header=template_nifti2.header, + ) + nifti2_to_mif(temp_nifti2, out_file) diff --git a/src/modelarrayio/cli/h5_export_cifti_file.py b/src/modelarrayio/cli/h5_export_cifti_file.py new file mode 100644 index 0000000..c8fc71d --- /dev/null +++ b/src/modelarrayio/cli/h5_export_cifti_file.py @@ -0,0 +1,87 @@ +"""Export one scalar matrix row from HDF5 to a CIFTI dscalar file.""" + +from __future__ import annotations + +import argparse +import logging +from functools import partial + +import nibabel as nb +import pandas as pd + +from modelarrayio.cli import utils as cli_utils +from modelarrayio.cli.parser_utils import ( + _is_file, + add_hdf5_scalar_export_args, + add_log_level_arg, +) + +logger = logging.getLogger(__name__) + + +def h5_export_cifti_file( + in_file, + scalar_name, + output_file, + column_index=None, + source_file=None, + cohort_file=None, + example_cifti=None, +): + row = cli_utils.load_hdf5_scalar_row( + in_file, + scalar_name, + column_index=column_index, + source_file=source_file, + ) + + if example_cifti is None: + cohort_df = pd.read_csv(cohort_file) + example_cifti = cohort_df['source_file'].iloc[0] + cifti = nb.load(example_cifti) + if row.shape[0] != cifti.shape[-1]: + raise ValueError( + f'Scalar row length ({row.shape[0]}) does not match CIFTI greyordinates ' + f'({cifti.shape[-1]}).' 
+ ) + + out_path = cli_utils.prepare_output_parent(output_file) + nb.Cifti2Image( + row.reshape(1, -1), + header=cifti.header, + nifti_header=cifti.nifti_header, + ).to_filename(out_path) + + +def h5_export_cifti_file_main(**kwargs): + log_level = kwargs.pop('log_level', 'INFO') + cli_utils.configure_logging(log_level) + h5_export_cifti_file(**kwargs) + return 0 + + +def _parse_h5_export_cifti_file(): + parser = argparse.ArgumentParser( + description='Export one row from scalars//values to a CIFTI file', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + is_file = partial(_is_file, parser=parser) + add_hdf5_scalar_export_args(parser) + + example_group = parser.add_mutually_exclusive_group(required=True) + example_group.add_argument( + '--cohort-file', + '--cohort_file', + help='Path to cohort CSV used to choose an example CIFTI file.', + type=is_file, + default=None, + ) + example_group.add_argument( + '--example-cifti', + '--example_cifti', + help='Path to an example CIFTI file used as output template.', + type=is_file, + default=None, + ) + add_log_level_arg(parser) + return parser diff --git a/src/modelarrayio/cli/h5_export_mif_file.py b/src/modelarrayio/cli/h5_export_mif_file.py new file mode 100644 index 0000000..74d6f25 --- /dev/null +++ b/src/modelarrayio/cli/h5_export_mif_file.py @@ -0,0 +1,68 @@ +"""Export one scalar matrix row from HDF5 to a MIF file.""" + +from __future__ import annotations + +import argparse +import logging +from functools import partial + +import nibabel as nb + +from modelarrayio.cli import utils as cli_utils +from modelarrayio.cli.parser_utils import ( + _is_file, + add_hdf5_scalar_export_args, + add_log_level_arg, +) +from modelarrayio.utils.fixels import mif_to_nifti2, nifti2_to_mif + +logger = logging.getLogger(__name__) + + +def h5_export_mif_file( + in_file, + scalar_name, + output_file, + example_mif, + column_index=None, + source_file=None, +): + row = cli_utils.load_hdf5_scalar_row( + in_file, + scalar_name, + 
column_index=column_index, + source_file=source_file, + ) + template_nifti2, _ = mif_to_nifti2(example_mif) + out_nifti2 = nb.Nifti2Image( + row.reshape(-1, 1, 1), + affine=template_nifti2.affine, + header=template_nifti2.header, + ) + out_path = cli_utils.prepare_output_parent(output_file) + nifti2_to_mif(out_nifti2, out_path) + + +def h5_export_mif_file_main(**kwargs): + log_level = kwargs.pop('log_level', 'INFO') + cli_utils.configure_logging(log_level) + h5_export_mif_file(**kwargs) + return 0 + + +def _parse_h5_export_mif_file(): + parser = argparse.ArgumentParser( + description='Export one row from scalars//values to a MIF file', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + is_file = partial(_is_file, parser=parser) + add_hdf5_scalar_export_args(parser) + parser.add_argument( + '--example-mif', + '--example_mif', + required=True, + type=is_file, + help='Path to an example MIF file used as output template.', + ) + add_log_level_arg(parser) + return parser diff --git a/src/modelarrayio/cli/h5_export_nifti_file.py b/src/modelarrayio/cli/h5_export_nifti_file.py new file mode 100644 index 0000000..3cfacfd --- /dev/null +++ b/src/modelarrayio/cli/h5_export_nifti_file.py @@ -0,0 +1,74 @@ +"""Export one scalar matrix row from HDF5 to a NIfTI file.""" + +from __future__ import annotations + +import argparse +import logging +from functools import partial + +import nibabel as nb +import numpy as np + +from modelarrayio.cli import utils as cli_utils +from modelarrayio.cli.parser_utils import ( + _is_file, + add_hdf5_scalar_export_args, + add_log_level_arg, +) + +logger = logging.getLogger(__name__) + + +def h5_export_nifti_file( + in_file, + scalar_name, + output_file, + group_mask_file, + column_index=None, + source_file=None, +): + row = cli_utils.load_hdf5_scalar_row( + in_file, + scalar_name, + column_index=column_index, + source_file=source_file, + ) + group_mask_img = nb.load(group_mask_file) + group_mask_matrix = group_mask_img.get_fdata() > 0 + 
num_voxels = int(group_mask_matrix.sum()) + if row.shape[0] != num_voxels: + raise ValueError( + f'Scalar row length ({row.shape[0]}) does not match group mask voxels ({num_voxels}).' + ) + + output = np.zeros(group_mask_matrix.shape, dtype=np.float32) + output[group_mask_matrix] = row.astype(np.float32) + header = group_mask_img.header.copy() + header.set_data_dtype(np.float32) + out_path = cli_utils.prepare_output_parent(output_file) + nb.Nifti1Image(output, affine=group_mask_img.affine, header=header).to_filename(out_path) + + +def h5_export_nifti_file_main(**kwargs): + log_level = kwargs.pop('log_level', 'INFO') + cli_utils.configure_logging(log_level) + h5_export_nifti_file(**kwargs) + return 0 + + +def _parse_h5_export_nifti_file(): + parser = argparse.ArgumentParser( + description='Export one row from scalars//values to a NIfTI file', + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + is_file = partial(_is_file, parser=parser) + add_hdf5_scalar_export_args(parser) + parser.add_argument( + '--group-mask-file', + '--group_mask_file', + required=True, + type=is_file, + help='Path to the group mask file used for original voxel flattening.', + ) + add_log_level_arg(parser) + return parser diff --git a/src/modelarrayio/cli/main.py b/src/modelarrayio/cli/main.py index 9cd26a9..ef1045b 100644 --- a/src/modelarrayio/cli/main.py +++ b/src/modelarrayio/cli/main.py @@ -6,6 +6,18 @@ from importlib.metadata import PackageNotFoundError, version from modelarrayio.cli.cifti_to_h5 import _parse_cifti_to_h5, cifti_to_h5_main +from modelarrayio.cli.h5_export_cifti_file import ( + _parse_h5_export_cifti_file, + h5_export_cifti_file_main, +) +from modelarrayio.cli.h5_export_mif_file import ( + _parse_h5_export_mif_file, + h5_export_mif_file_main, +) +from modelarrayio.cli.h5_export_nifti_file import ( + _parse_h5_export_nifti_file, + h5_export_nifti_file_main, +) from modelarrayio.cli.h5_to_cifti import _parse_h5_to_cifti, h5_to_cifti_main from 
modelarrayio.cli.h5_to_mif import _parse_h5_to_mif, h5_to_mif_main from modelarrayio.cli.h5_to_nifti import _parse_h5_to_nifti, h5_to_nifti_main @@ -19,6 +31,9 @@ ('h5-to-mif', _parse_h5_to_mif, h5_to_mif_main), ('h5-to-nifti', _parse_h5_to_nifti, h5_to_nifti_main), ('h5-to-cifti', _parse_h5_to_cifti, h5_to_cifti_main), + ('h5-export-mif-file', _parse_h5_export_mif_file, h5_export_mif_file_main), + ('h5-export-nifti-file', _parse_h5_export_nifti_file, h5_export_nifti_file_main), + ('h5-export-cifti-file', _parse_h5_export_cifti_file, h5_export_cifti_file_main), ] diff --git a/src/modelarrayio/cli/mif_to_h5.py b/src/modelarrayio/cli/mif_to_h5.py index 38ed1cc..1932c4d 100644 --- a/src/modelarrayio/cli/mif_to_h5.py +++ b/src/modelarrayio/cli/mif_to_h5.py @@ -12,8 +12,9 @@ import pandas as pd from tqdm import tqdm +from modelarrayio.cli import diagnostics as cli_diagnostics from modelarrayio.cli import utils as cli_utils -from modelarrayio.cli.parser_utils import _is_file, add_to_modelarray_args +from modelarrayio.cli.parser_utils import _is_file, add_diagnostics_args, add_to_modelarray_args from modelarrayio.utils.fixels import gather_fixels, mif_to_nifti2 logger = logging.getLogger(__name__) @@ -33,6 +34,9 @@ def mif_to_h5( target_chunk_mb=2.0, workers=None, s3_workers=1, + no_diagnostics=False, + diagnostics_dir=None, + diagnostic_maps=None, ): """Load all fixeldb data and write to an HDF5 or TileDB file. @@ -67,6 +71,12 @@ def mif_to_h5( Has no effect when ``backend='hdf5'``. s3_workers : :obj:`int` Number of parallel workers for S3 downloads. Default 1. + no_diagnostics : :obj:`bool` + Disable diagnostic outputs in native format. + diagnostics_dir : :obj:`str` or :obj:`None` + Output directory for diagnostics. Defaults to ``_diagnostics``. + diagnostic_maps : :obj:`list` or :obj:`None` + Diagnostic maps to write. Supported: ``mean``, ``element_id``, ``n_non_nan``. 
Returns ------- @@ -79,17 +89,39 @@ def mif_to_h5( # gather cohort data cohort_df = pd.read_csv(cohort_file) + maps_to_write = cli_utils.normalize_diagnostic_maps(diagnostic_maps) # upload each cohort's data scalars = defaultdict(list) sources_lists = defaultdict(list) + template_nifti2 = None logger.info('Extracting .mif data...') for row in tqdm(cohort_df.itertuples(index=False), total=cohort_df.shape[0]): scalar_file = row.source_file - _scalar_img, scalar_data = mif_to_nifti2(scalar_file) + scalar_img, scalar_data = mif_to_nifti2(scalar_file) + if template_nifti2 is None: + template_nifti2 = scalar_img scalars[row.scalar_name].append(scalar_data) sources_lists[row.scalar_name].append(row.source_file) + if not no_diagnostics: + output_diag_dir = ( + Path(diagnostics_dir) + if diagnostics_dir is not None + else cli_utils.default_diagnostics_dir(output_path) + ) + output_diag_dir.mkdir(parents=True, exist_ok=True) + for scalar_name, rows in scalars.items(): + cli_diagnostics.verify_mif_element_mapping(template_nifti2, rows[0].shape[0]) + diagnostics = cli_diagnostics.summarize_rows(rows) + cli_diagnostics.write_mif_diagnostics( + maps=maps_to_write, + scalar_name=scalar_name, + diagnostics=diagnostics, + template_nifti2=template_nifti2, + output_dir=output_diag_dir, + ) + # Write the output if backend == 'hdf5': output_path = cli_utils.prepare_output_parent(output_path) @@ -155,4 +187,5 @@ def _parse_mif_to_h5(): # Common arguments add_to_modelarray_args(parser, default_output='fixelarray.h5') + add_diagnostics_args(parser) return parser diff --git a/src/modelarrayio/cli/nifti_to_h5.py b/src/modelarrayio/cli/nifti_to_h5.py index 28e2422..39d2e69 100644 --- a/src/modelarrayio/cli/nifti_to_h5.py +++ b/src/modelarrayio/cli/nifti_to_h5.py @@ -12,8 +12,9 @@ import numpy as np import pandas as pd +from modelarrayio.cli import diagnostics as cli_diagnostics from modelarrayio.cli import utils as cli_utils -from modelarrayio.cli.parser_utils import _is_file, 
add_to_modelarray_args +from modelarrayio.cli.parser_utils import _is_file, add_diagnostics_args, add_to_modelarray_args from modelarrayio.utils.voxels import _load_cohort_voxels logger = logging.getLogger(__name__) @@ -32,6 +33,9 @@ def nifti_to_h5( target_chunk_mb=2.0, workers=None, s3_workers=1, + no_diagnostics=False, + diagnostics_dir=None, + diagnostic_maps=None, ): """Load all volume data and write to an HDF5 or TileDB file. @@ -64,6 +68,12 @@ def nifti_to_h5( Has no effect when ``backend='hdf5'``. s3_workers : :obj:`int` Number of parallel workers for S3 downloads. Default 1. + no_diagnostics : :obj:`bool` + Disable diagnostic outputs in native format. + diagnostics_dir : :obj:`str` or :obj:`None` + Output directory for diagnostics. Defaults to ``_diagnostics``. + diagnostic_maps : :obj:`list` or :obj:`None` + Diagnostic maps to write. Supported: ``mean``, ``element_id``, ``n_non_nan``. """ cohort_df = pd.read_csv(cohort_file) output_path = Path(output) @@ -83,6 +93,26 @@ def nifti_to_h5( logger.info('Extracting NIfTI data...') scalars, sources_lists = _load_cohort_voxels(cohort_df, group_mask_matrix, s3_workers) + maps_to_write = cli_utils.normalize_diagnostic_maps(diagnostic_maps) + + if not no_diagnostics: + output_diag_dir = ( + Path(diagnostics_dir) + if diagnostics_dir is not None + else cli_utils.default_diagnostics_dir(output_path) + ) + output_diag_dir.mkdir(parents=True, exist_ok=True) + cli_diagnostics.verify_nifti_element_mapping(group_mask_img, group_mask_matrix) + for scalar_name, rows in scalars.items(): + diagnostics = cli_diagnostics.summarize_rows(rows) + cli_diagnostics.write_nifti_diagnostics( + maps=maps_to_write, + scalar_name=scalar_name, + diagnostics=diagnostics, + group_mask_img=group_mask_img, + group_mask_matrix=group_mask_matrix, + output_dir=output_diag_dir, + ) if backend == 'hdf5': output_path = cli_utils.prepare_output_parent(output_path) @@ -140,4 +170,5 @@ def _parse_nifti_to_h5(): # Common arguments 
add_to_modelarray_args(parser, default_output='voxelarray.h5') + add_diagnostics_args(parser) return parser diff --git a/src/modelarrayio/cli/parser_utils.py b/src/modelarrayio/cli/parser_utils.py index 55c2c4a..2cf6a9b 100644 --- a/src/modelarrayio/cli/parser_utils.py +++ b/src/modelarrayio/cli/parser_utils.py @@ -135,6 +135,35 @@ def add_log_level_arg(parser): return parser +def add_diagnostics_args(parser): + parser.add_argument( + '--no-diagnostics', + action='store_true', + help='Disable writing conversion diagnostics in the native imaging format.', + default=False, + ) + parser.add_argument( + '--diagnostics-dir', + '--diagnostics_dir', + help=( + 'Directory for diagnostic outputs. Defaults to _diagnostics ' + 'next to --output.' + ), + default=None, + ) + parser.add_argument( + '--diagnostic-maps', + '--diagnostic_maps', + nargs='+', + default=['mean', 'element_id', 'n_non_nan'], + help=( + 'Diagnostic maps to write. Supports space-separated values and/or comma-separated ' + 'tokens. Valid values: mean, element_id, n_non_nan.' 
+ ), + ) + return parser + + def add_from_modelarray_args(parser): parser.add_argument( '--analysis-name', @@ -163,6 +192,51 @@ def add_from_modelarray_args(parser): return parser +def add_hdf5_scalar_export_args(parser): + parser.add_argument( + '--input-hdf5', + '--input_hdf5', + dest='in_file', + help='Path to an HDF5 (.h5) file containing scalar matrices under scalars//values.', + type=partial(_is_file, parser=parser), + required=True, + ) + parser.add_argument( + '--scalar-name', + '--scalar_name', + required=True, + help='Scalar name under /scalars in the input HDF5 file.', + ) + parser.add_argument( + '--output-file', + '--output_file', + required=True, + help='Output file path in the native imaging format.', + ) + + selection = parser.add_mutually_exclusive_group(required=True) + selection.add_argument( + '--column-index', + '--column_index', + '--subject-index', + '--subject_index', + type=int, + dest='column_index', + help='0-based row index within scalars//values.', + default=None, + ) + selection.add_argument( + '--source-file', + '--source_file', + help=( + 'Original source path to match against scalars//column_names ' + 'for row selection.' 
+ ), + default=None, + ) + return parser + + def _path_exists(path: str | Path | None, parser) -> Path: """Ensure a given path exists.""" if path is None or not Path(path).exists(): diff --git a/src/modelarrayio/cli/utils.py b/src/modelarrayio/cli/utils.py index 36918b4..74c51e2 100644 --- a/src/modelarrayio/cli/utils.py +++ b/src/modelarrayio/cli/utils.py @@ -37,6 +37,35 @@ def prepare_output_parent(output_file: str | Path) -> Path: return output_path +def normalize_diagnostic_maps(values: Sequence[str] | None) -> list[str]: + """Normalize diagnostic map names from CLI input.""" + if not values: + return ['mean', 'element_id', 'n_non_nan'] + + maps: list[str] = [] + valid = {'mean', 'element_id', 'n_non_nan'} + for value in values: + for token in str(value).split(','): + name = token.strip() + if not name: + continue + if name not in valid: + valid_str = ', '.join(sorted(valid)) + raise ValueError(f'Invalid diagnostic map {name!r}. Expected one of: {valid_str}.') + if name not in maps: + maps.append(name) + if not maps: + return ['mean', 'element_id', 'n_non_nan'] + return maps + + +def default_diagnostics_dir(output: str | Path) -> Path: + """Return the default diagnostics directory for a conversion output path.""" + output_path = Path(output) + stem = output_path.stem if output_path.suffix else output_path.name + return output_path.parent / f'{stem}_diagnostics' + + def write_table_dataset( h5_file: h5py.File, dataset_name: str, @@ -167,6 +196,31 @@ def read_result_names( return [f'component{n + 1:03d}' for n in range(results_matrix.shape[0])] +def load_hdf5_scalar_row( + in_file: str | Path, + scalar_name: str, + *, + column_index: int | None = None, + source_file: str | None = None, +) -> np.ndarray: + """Load one row from scalars//values in an HDF5 file.""" + with h5py.File(in_file, 'r') as h5_data: + dataset_path = f'scalars/{scalar_name}/values' + if dataset_path not in h5_data: + raise ValueError(f'Scalar dataset not found: {dataset_path}') + dataset = 
h5_data[dataset_path] + num_subjects = dataset.shape[0] + + resolved_index = _resolve_column_index( + h5_data, + scalar_name, + num_subjects=num_subjects, + column_index=column_index, + source_file=source_file, + ) + return np.asarray(dataset[resolved_index, :], dtype=np.float32) + + def _decode_names(values: object) -> list[str]: if isinstance(values, np.ndarray): sequence = values.tolist() @@ -183,3 +237,43 @@ def _decode_names(values: object) -> list[str]: text = str(value) decoded.append(text.rstrip('\x00').strip()) return [name for name in decoded if name] + + +def _resolve_column_index( + h5_data: h5py.File, + scalar_name: str, + *, + num_subjects: int, + column_index: int | None, + source_file: str | None, +) -> int: + if column_index is not None: + if column_index < 0 or column_index >= num_subjects: + raise ValueError( + f'--column-index {column_index} is out of bounds for scalar ' + f'{scalar_name!r} with {num_subjects} subjects.' + ) + return int(column_index) + + if source_file is None: + raise ValueError('Provide one of --column-index or --source-file.') + + names_path = f'scalars/{scalar_name}/column_names' + if names_path not in h5_data: + raise ValueError( + f'Column names dataset not found for scalar {scalar_name!r} at {names_path}. ' + 'Use --column-index instead.' + ) + + names = _decode_names(h5_data[names_path][()]) + matches = [idx for idx, name in enumerate(names) if name == source_file] + if not matches: + raise ValueError( + f'--source-file {source_file!r} was not found in scalars/{scalar_name}/column_names.' + ) + if len(matches) > 1: + raise ValueError( + f'--source-file {source_file!r} matched multiple rows ({matches}) for scalar ' + f'{scalar_name!r}. Use --column-index instead.' 
+ ) + return matches[0] diff --git a/test/test_cifti_cli.py b/test/test_cifti_cli.py index 2fb3fba..0b6db7b 100644 --- a/test/test_cifti_cli.py +++ b/test/test_cifti_cli.py @@ -51,6 +51,7 @@ def test_concifti_cli_creates_expected_hdf5(tmp_path, monkeypatch): ) out_h5 = tmp_path / 'out_cifti.h5' + diag_dir = tmp_path / 'out_cifti_diagnostics' monkeypatch.chdir(tmp_path) assert ( modelarrayio_main( @@ -75,6 +76,9 @@ def test_concifti_cli_creates_expected_hdf5(tmp_path, monkeypatch): == 0 ) assert op.exists(out_h5) + assert (diag_dir / 'THICK_mean.dscalar.nii').exists() + assert (diag_dir / 'THICK_element_id.dscalar.nii').exists() + assert (diag_dir / 'THICK_n_non_nan.dscalar.nii').exists() # Validate HDF5 contents with h5py.File(out_h5, 'r') as h5: @@ -107,3 +111,38 @@ def test_concifti_cli_creates_expected_hdf5(tmp_path, monkeypatch): # Spot-check a couple values assert np.isclose(float(dset[0, 0]), 0.0) assert np.isclose(float(dset[1, 0]), 1.0) + + +def test_concifti_cli_no_diagnostics_disables_outputs(tmp_path, monkeypatch): + vol_shape = (2, 2, 2) + mask = np.zeros(vol_shape, dtype=bool) + mask[0, 0, 0] = True + mask[1, 1, 1] = True + + path = tmp_path / 'sub-1.dscalar.nii' + _make_synthetic_cifti_dscalar(mask, np.array([1.0, 2.0], dtype=np.float32)).to_filename(path) + + cohort_csv = tmp_path / 'cohort_cifti.csv' + with cohort_csv.open('w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=['scalar_name', 'source_file']) + writer.writeheader() + writer.writerow({'scalar_name': 'THICK', 'source_file': path.name}) + + out_h5 = tmp_path / 'out_cifti.h5' + diag_dir = tmp_path / 'out_cifti_diagnostics' + monkeypatch.chdir(tmp_path) + assert ( + modelarrayio_main( + [ + 'cifti-to-h5', + '--cohort-file', + str(cohort_csv), + '--output', + str(out_h5), + '--no-diagnostics', + ] + ) + == 0 + ) + assert out_h5.exists() + assert not diag_dir.exists() diff --git a/test/test_diagnostics.py b/test/test_diagnostics.py new file mode 100644 index 0000000..d456684 --- 
/dev/null +++ b/test/test_diagnostics.py @@ -0,0 +1,25 @@ +"""Unit tests for conversion diagnostics helpers.""" + +from __future__ import annotations + +import nibabel as nb +import numpy as np +import pytest + +from modelarrayio.cli import diagnostics as cli_diagnostics + + +def test_verify_nifti_element_mapping_raises_on_mismatch(monkeypatch): + group_mask_matrix = np.zeros((2, 2, 2), dtype=bool) + group_mask_matrix[0, 0, 0] = True + group_mask_matrix[1, 1, 1] = True + group_mask_img = nb.Nifti1Image(group_mask_matrix.astype(np.uint8), affine=np.eye(4)) + + monkeypatch.setattr( + cli_diagnostics, + 'flattened_image', + lambda *_args, **_kwargs: np.array([1.0, 0.0], dtype=np.float32), + ) + + with pytest.raises(ValueError, match='Element ID mapping check failed'): + cli_diagnostics.verify_nifti_element_mapping(group_mask_img, group_mask_matrix) diff --git a/test/test_h5_export_cli.py b/test/test_h5_export_cli.py new file mode 100644 index 0000000..a2ffdf6 --- /dev/null +++ b/test/test_h5_export_cli.py @@ -0,0 +1,183 @@ +"""Tests for h5-export-*-file commands.""" + +from __future__ import annotations + +import h5py +import nibabel as nb +import numpy as np +from nibabel.cifti2.cifti2_axes import BrainModelAxis, ScalarAxis + +from modelarrayio.cli import h5_export_mif_file as export_mif_cli +from modelarrayio.cli.main import main as modelarrayio_main + + +def _make_nifti(data): + return nb.Nifti1Image(data.astype(np.float32), affine=np.eye(4)) + + +def _make_synthetic_cifti(mask_bool: np.ndarray, values: np.ndarray) -> nb.Cifti2Image: + scalar_axis = ScalarAxis(['synthetic']) + brain_axis = BrainModelAxis.from_mask(mask_bool) + header = nb.cifti2.Cifti2Header.from_axes((scalar_axis, brain_axis)) + data_2d = values.reshape(1, -1).astype(np.float32) + return nb.Cifti2Image(data_2d, header=header) + + +def test_h5_export_nifti_file_cli_column_index_and_source_file(tmp_path): + shape = (3, 3, 3) + group_mask = np.zeros(shape, dtype=bool) + coords = [(0, 0, 0), (1, 1, 
1), (2, 2, 2)] + for coord in coords: + group_mask[coord] = True + group_mask_file = tmp_path / 'group_mask.nii.gz' + _make_nifti(group_mask.astype(np.uint8)).to_filename(group_mask_file) + + in_file = tmp_path / 'input.h5' + with h5py.File(in_file, 'w') as h5: + h5.create_dataset( + 'scalars/FA/values', + data=np.array( + [ + [10.0, 20.0, 30.0], + [11.0, 21.0, np.nan], + ], + dtype=np.float32, + ), + ) + h5.create_dataset( + 'scalars/FA/column_names', + data=np.array(['sub-01_scalar.nii.gz', 'sub-02_scalar.nii.gz'], dtype=object), + dtype=h5py.string_dtype('utf-8'), + ) + + out_column_index = tmp_path / 'column_index.nii.gz' + assert ( + modelarrayio_main( + [ + 'h5-export-nifti-file', + '--input-hdf5', + str(in_file), + '--scalar-name', + 'FA', + '--column-index', + '1', + '--group-mask-file', + str(group_mask_file), + '--output-file', + str(out_column_index), + ] + ) + == 0 + ) + out_data = nb.load(out_column_index).get_fdata() + assert out_data[coords[0]] == 11.0 + assert out_data[coords[1]] == 21.0 + assert np.isnan(out_data[coords[2]]) + + out_source_file = tmp_path / 'source_file.nii.gz' + assert ( + modelarrayio_main( + [ + 'h5-export-nifti-file', + '--input-hdf5', + str(in_file), + '--scalar-name', + 'FA', + '--source-file', + 'sub-01_scalar.nii.gz', + '--group-mask-file', + str(group_mask_file), + '--output-file', + str(out_source_file), + ] + ) + == 0 + ) + out_data_source = nb.load(out_source_file).get_fdata() + assert out_data_source[coords[0]] == 10.0 + assert out_data_source[coords[1]] == 20.0 + assert out_data_source[coords[2]] == 30.0 + + +def test_h5_export_cifti_file_cli(tmp_path): + mask = np.zeros((2, 2, 2), dtype=bool) + mask[0, 0, 0] = True + mask[1, 1, 1] = True + template = tmp_path / 'template.dscalar.nii' + _make_synthetic_cifti(mask, np.array([0.0, 0.0], dtype=np.float32)).to_filename(template) + + in_file = tmp_path / 'input.h5' + with h5py.File(in_file, 'w') as h5: + h5.create_dataset( + 'scalars/THICK/values', + data=np.array([[1.0, 
2.0], [3.0, 4.0]], dtype=np.float32), + ) + + out_file = tmp_path / 'exported.dscalar.nii' + assert ( + modelarrayio_main( + [ + 'h5-export-cifti-file', + '--input-hdf5', + str(in_file), + '--scalar-name', + 'THICK', + '--column-index', + '1', + '--example-cifti', + str(template), + '--output-file', + str(out_file), + ] + ) + == 0 + ) + exported = nb.load(out_file).get_fdata().squeeze() + assert np.array_equal(exported, np.array([3.0, 4.0], dtype=np.float64)) + + +def test_h5_export_mif_file_cli(monkeypatch, tmp_path): + in_file = tmp_path / 'input.h5' + with h5py.File(in_file, 'w') as h5: + h5.create_dataset( + 'scalars/FD/values', + data=np.array([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], dtype=np.float32), + ) + + example_mif = tmp_path / 'example.mif' + example_mif.write_text('dummy') + out_file = tmp_path / 'subject.mif' + + captured = {} + + def _fake_mif_to_nifti2(_path): + template = nb.Nifti2Image(np.zeros((3, 1, 1), dtype=np.float32), affine=np.eye(4)) + return template, np.zeros(3, dtype=np.float32) + + def _fake_nifti2_to_mif(nifti2_image, mif_file): + captured['data'] = nifti2_image.get_fdata().squeeze().copy() + mif_file.write_text('fake mif') + + monkeypatch.setattr(export_mif_cli, 'mif_to_nifti2', _fake_mif_to_nifti2) + monkeypatch.setattr(export_mif_cli, 'nifti2_to_mif', _fake_nifti2_to_mif) + + assert ( + modelarrayio_main( + [ + 'h5-export-mif-file', + '--input-hdf5', + str(in_file), + '--scalar-name', + 'FD', + '--column-index', + '1', + '--example-mif', + str(example_mif), + '--output-file', + str(out_file), + ] + ) + == 0 + ) + assert out_file.exists() + assert np.array_equal(captured['data'], np.array([8.0, 9.0, 10.0], dtype=np.float64)) diff --git a/test/test_voxels_cli.py b/test/test_voxels_cli.py index c75acf7..808b06c 100644 --- a/test/test_voxels_cli.py +++ b/test/test_voxels_cli.py @@ -71,6 +71,7 @@ def test_convoxel_cli_creates_expected_hdf5(tmp_path, monkeypatch): ) out_h5 = tmp_path / 'out.h5' + diag_dir = tmp_path / 'out_diagnostics' 
monkeypatch.chdir(tmp_path) assert ( modelarrayio_main( @@ -97,6 +98,9 @@ def test_convoxel_cli_creates_expected_hdf5(tmp_path, monkeypatch): == 0 ) assert op.exists(out_h5) + assert (diag_dir / 'FA_mean.nii.gz').exists() + assert (diag_dir / 'FA_element_id.nii.gz').exists() + assert (diag_dir / 'FA_n_non_nan.nii.gz').exists() # Validate HDF5 contents with h5py.File(out_h5, 'r') as h5: @@ -143,6 +147,58 @@ def test_convoxel_cli_creates_expected_hdf5(tmp_path, monkeypatch): assert np.isclose(v1, expected_s1, equal_nan=True) +def test_convoxel_cli_no_diagnostics_disables_outputs(tmp_path, monkeypatch): + shape = (3, 3, 3) + group_mask = np.zeros(shape, dtype=bool) + group_mask[0, 0, 0] = True + group_mask[1, 1, 1] = True + + group_mask_file = tmp_path / 'group_mask.nii.gz' + _make_nifti(group_mask.astype(np.uint8)).to_filename(group_mask_file) + + scalar = np.zeros(shape, dtype=np.float32) + scalar[0, 0, 0] = 1.0 + scalar[1, 1, 1] = 2.0 + scalar_path = tmp_path / 'sub-1_scalar.nii.gz' + _make_nifti(scalar).to_filename(scalar_path) + + mask_path = tmp_path / 'sub-1_mask.nii.gz' + _make_nifti(group_mask.astype(np.uint8)).to_filename(mask_path) + + cohort_csv = tmp_path / 'cohort.csv' + with cohort_csv.open('w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=['scalar_name', 'source_file', 'source_mask_file']) + writer.writeheader() + writer.writerow( + { + 'scalar_name': 'FA', + 'source_file': scalar_path.name, + 'source_mask_file': mask_path.name, + } + ) + + out_h5 = tmp_path / 'out.h5' + diag_dir = tmp_path / 'out_diagnostics' + monkeypatch.chdir(tmp_path) + assert ( + modelarrayio_main( + [ + 'nifti-to-h5', + '--group-mask-file', + str(group_mask_file), + '--cohort-file', + str(cohort_csv), + '--output', + str(out_h5), + '--no-diagnostics', + ] + ) + == 0 + ) + assert out_h5.exists() + assert not diag_dir.exists() + + def test_h5_to_nifti_cli_writes_results_with_dataset_column_names(tmp_path): shape = (3, 3, 3) group_mask = np.zeros(shape, dtype=bool)