From dc73009cc19dad22183cfb02c240c3cb69920841 Mon Sep 17 00:00:00 2001 From: nprzrosas Date: Tue, 9 Jun 2026 20:22:52 -0400 Subject: [PATCH] Add model-selection strategies and local BioModels batch mode --- README.md | 43 ++ biomodels_batch.py | 398 ++++++++++++++++-- data2sbml.py | 761 ++++++++++++++++++++++++++++++++-- tests/test_biomodels_batch.py | 85 +++- tests/test_model_selection.py | 218 ++++++++++ 5 files changed, 1434 insertions(+), 71 deletions(-) create mode 100644 tests/test_model_selection.py diff --git a/README.md b/README.md index 96a6d96..0ced877 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,26 @@ uv run python data2sbml.py \ --maxsize 24 ``` +### Model-selection strategies + +The final inferred model can now be selected with three strategies ordered from +cheapest to most expensive: + +- `--selection-strategy pysr_model_selection`: choose one safe equation per species using PySR's local `model_selection` rule. +- `--selection-strategy global_rmse`: shortlist per-species candidates, reconstruct full models, and choose the model with the lowest simulated mean RMSE. +- `--selection-strategy global_multiobjective`: search full reconstructed models and choose the model with the lowest `rmse_mean + penalty * total_complexity`. + +Example: + +```bash +uv run python data2sbml.py \ + --sbml demo_polynomial_2d.xml \ + --species x y \ + --selection-strategy global_rmse \ + --selection-top-k 3 \ + --selection-max-model-evals 64 +``` + ### Harder example The BioModels file `BIOMD0000000346_url.xml` is also runnable with the same pipeline, but it is a harder symbolic-regression problem and should be treated as a stress test rather than the baseline validation case. @@ -52,11 +72,34 @@ To fetch curated BioModels identifiers via `biocompose`, cache their SBML/SED-ML uv run python biomodels_batch.py --limit 25 --max-species 6 ``` +To run the batch directly over a local directory of downloaded BioModels files such as `artifacts/all_curated_biomodels`: + +```bash +uv run python biomodels_batch.py \ + --local-models-dir artifacts/all_curated_biomodels \ + --limit 25 \ + --max-species 6 +``` + +When a model has no `.sedml` file, or the SED-ML does not expose a usable `UniformTimeCourse`, the batch now falls back to configurable simulation settings instead of skipping the model: + +```bash +uv run python biomodels_batch.py \ + --local-models-dir artifacts/all_curated_biomodels \ + --fallback-start 0 \ + --fallback-duration 20 \ + --fallback-points 801 +``` + +Use `--require-sedml` if you want to keep the previous strict behavior and fail on models without a usable SED-ML file. + The batch summary is written to `artifacts/biomodels_batch_summary.tsv` and now includes: - `original_model_equation_count` - `inferred_model_pysr_complexity` - `inferred_model_node_count` +- `simulation_settings_source` +- `simulation_settings_note` To only cache all curated BioModels files without running symbolic regression: diff --git a/biomodels_batch.py b/biomodels_batch.py index e3c7dab..74ca7a0 100644 --- a/biomodels_batch.py +++ b/biomodels_batch.py @@ -7,6 +7,7 @@ import json import subprocess import sys +from dataclasses import dataclass from pathlib import Path from types import ModuleType from typing import Any @@ -16,20 +17,27 @@ DEFAULT_LIMIT = 10 DEFAULT_MAX_POINTS = 1001 DEFAULT_MAX_SPECIES = 6 +DEFAULT_FALLBACK_START = 0.0 +DEFAULT_FALLBACK_DURATION = 20.0 +DEFAULT_FALLBACK_POINTS = 801 ROOT_DIR = Path(__file__).resolve().parent REPO_ROOT = ROOT_DIR.parent PIPELINE_SCRIPT = ROOT_DIR / "data2sbml.py" DEFAULT_ARTIFACTS_DIR = ROOT_DIR / "artifacts" DEFAULT_DOWNLOAD_DIR = DEFAULT_ARTIFACTS_DIR / "downloaded_biomodels" +DEFAULT_LOCAL_MODELS_DIR = DEFAULT_ARTIFACTS_DIR / "all_curated_biomodels" DEFAULT_SUMMARY_PATH = DEFAULT_ARTIFACTS_DIR / "biomodels_batch_summary.tsv" SUMMARY_COLUMNS = [ "model_id", "status", "exit_code", + "start", "utc_duration", "utc_points", "duration", "points", + "simulation_settings_source", + "simulation_settings_note", "original_model_equation_count", "selected_species_count", "inferred_equation_count", @@ -45,6 +53,48 @@ ] +@dataclass(frozen=True) +class LoadedBiomodel: + """Represent one BioModels asset bundle resolved for the batch runner. + + Args: + biomodel_id: Model identifier or local directory name. + sbml_path: Resolved SBML file used by the pipeline. + sedml_path: Optional SED-ML file associated with the model. + utc: Optional UniformTimeCourse specification parsed from SED-ML. + simulation_note: Short note describing any SED-ML fallback condition. + """ + + biomodel_id: str + sbml_path: Path + sedml_path: Path | None + utc: Any | None + simulation_note: str = "" + + +@dataclass(frozen=True) +class SimulationSettings: + """Store effective simulation settings passed to ``data2sbml.py``. + + Args: + start: Simulation start time. + duration: Simulated duration. + points: Number of sampled time points. + source: Human-readable source of the settings. + note: Short note describing the selection or fallback reason. + utc_duration: Duration requested by SED-ML when available. + utc_points: Point count requested by SED-ML when available. + """ + + start: float + duration: float + points: int + source: str + note: str + utc_duration: float | str = "" + utc_points: int | str = "" + + def parse_args() -> argparse.Namespace: """Parse command-line arguments for the BioModels batch runner. @@ -105,6 +155,14 @@ def parse_args() -> argparse.Namespace: default=str(DEFAULT_DOWNLOAD_DIR), help="Directory where downloaded SBML and SED-ML files are cached.", ) + parser.add_argument( + "--local-models-dir", + help=( + "Optional directory containing local BioModels subdirectories with " + "SBML files and optional SED-ML files. When provided, remote " + "fetching is bypassed." + ), + ) parser.add_argument( "--summary-path", default=str(DEFAULT_SUMMARY_PATH), @@ -126,12 +184,44 @@ def parse_args() -> argparse.Namespace: type=int, help="Optional PySR maxsize override.", ) + parser.add_argument( + "--fallback-start", + type=float, + default=DEFAULT_FALLBACK_START, + help=( + "Start time used when no usable SED-ML UniformTimeCourse is " + f"available (default: {DEFAULT_FALLBACK_START})." + ), + ) + parser.add_argument( + "--fallback-duration", + type=float, + default=DEFAULT_FALLBACK_DURATION, + help=( + "Duration used when no usable SED-ML UniformTimeCourse is " + f"available (default: {DEFAULT_FALLBACK_DURATION})." + ), + ) + parser.add_argument( + "--fallback-points", + type=int, + default=DEFAULT_FALLBACK_POINTS, + help=( + "Point count used when no usable SED-ML UniformTimeCourse is " + f"available (default: {DEFAULT_FALLBACK_POINTS})." + ), + ) parser.add_argument( "--bad-fit-rmse-threshold", type=float, default=0.1, help="RMSE threshold forwarded to data2sbml.py.", ) + parser.add_argument( + "--require-sedml", + action="store_true", + help="Fail on models that do not include a usable SED-ML file.", + ) parser.add_argument( "--curated-only", action="store_true", @@ -179,6 +269,35 @@ def select_model_ids(args: argparse.Namespace, module: ModuleType) -> list[str]: ) +def select_local_model_ids( + args: argparse.Namespace, + local_models_dir: Path, +) -> list[str]: + """Resolve model identifiers from a local directory of BioModels folders. + + Args: + args: Parsed CLI arguments. + local_models_dir: Root directory containing one subdirectory per model. + + Returns: + Ordered list of local model identifiers. + """ + + if args.model_id: + return list(dict.fromkeys(args.model_id)) + + model_ids = sorted( + child.name for child in local_models_dir.iterdir() if child.is_dir() + ) + if args.curated_only: + model_ids = [model_id for model_id in model_ids if model_id.startswith("BIOMD")] + if args.offset: + model_ids = model_ids[args.offset :] + if args.limit > 0: + model_ids = model_ids[: args.limit] + return model_ids + + def resolve_species_ids(sbml_path: Path, max_species: int) -> list[str]: """Resolve floating species IDs and enforce a configurable size cap. @@ -206,6 +325,7 @@ def run_pipeline( sbml_path: Path, output_dir: Path, species_ids: list[str], + start: float, duration: float, points: int, config_path: Path, @@ -219,6 +339,7 @@ def run_pipeline( sbml_path: Input SBML file. output_dir: Directory where artifacts are written. species_ids: Floating species used by symbolic regression. + start: Simulation start time. duration: Simulated duration. points: Number of output samples. config_path: PySR config file path. @@ -237,6 +358,8 @@ def run_pipeline( str(sbml_path), "--output-dir", str(output_dir), + "--start", + str(start), "--duration", str(duration), "--points", @@ -257,6 +380,159 @@ def run_pipeline( return int(completed.returncode) +def load_local_biomodel( + *, + biomodel_id: str, + local_models_dir: Path, + module: ModuleType, + require_sedml: bool, +) -> LoadedBiomodel: + """Load one model from a local BioModels directory. + + Args: + biomodel_id: Local model directory name. + local_models_dir: Root directory containing per-model subdirectories. + module: Imported helper module with SED-ML utilities. + require_sedml: Whether a usable SED-ML file is mandatory. + + Returns: + Resolved local model bundle ready for simulation. + """ + + model_dir = local_models_dir / biomodel_id + if not model_dir.is_dir(): + raise ValueError(f"{biomodel_id}: local directory does not exist: {model_dir}") + + entry_files = [child for child in model_dir.iterdir() if child.is_file()] + sedml_entry = module.find_first_sedml(entry_files) + sbml_entry = module.find_first_sbml(entry_files) + if sbml_entry is None: + raise ValueError( + f"{biomodel_id}: could not find an SBML (.xml/.sbml) file in {model_dir}." + ) + if require_sedml and sedml_entry is None: + raise ValueError(f"{biomodel_id}: could not find a .sedml file in {model_dir}.") + + sbml_path = Path(sbml_entry).resolve() + sedml_path = Path(sedml_entry).resolve() if sedml_entry is not None else None + utc = None + simulation_note = "missing_sedml" if sedml_path is None else "" + + if sedml_path is not None: + try: + sed_doc = module.read_sedml_doc(str(sedml_path)) + try: + utc = module.extract_first_uniform_time_course(sed_doc) + except ValueError as exc: + simulation_note = str(exc) + resolved_sbml_path = Path( + module.resolve_sbml_source_from_sedml( + sed_doc, + str(model_dir.resolve()), + str(sbml_path), + ) + ).resolve() + if resolved_sbml_path.exists(): + sbml_path = resolved_sbml_path + except Exception as exc: + if require_sedml: + raise + simulation_note = f"sedml_unusable: {exc}" + + return LoadedBiomodel( + biomodel_id=biomodel_id, + sbml_path=sbml_path, + sedml_path=sedml_path, + utc=utc, + simulation_note=simulation_note, + ) + + +def load_remote_biomodel( + *, + biomodel_id: str, + module: ModuleType, + download_dir: Path, + require_sedml: bool, +) -> LoadedBiomodel: + """Load one model from the BioModels API into the local cache. + + Args: + biomodel_id: BioModels identifier to load. + module: Imported helper module with BioModels accessors. + download_dir: Cache directory for fetched artifacts. + require_sedml: Whether a usable SED-ML file is mandatory. + + Returns: + Resolved fetched model bundle ready for simulation. + """ + + metadata = module.biomodels.get_metadata(biomodel_id) + loaded_model = module.load_biomodel( + biomodel_id, + metadata, + stable_root=str(download_dir), + require_sedml=require_sedml, + require_utc=False, + ) + simulation_note = "" + if loaded_model.sedml_path is None: + simulation_note = "missing_sedml" + + return LoadedBiomodel( + biomodel_id=loaded_model.biomodel_id, + sbml_path=Path(loaded_model.sbml_path).resolve(), + sedml_path=( + Path(loaded_model.sedml_path).resolve() + if loaded_model.sedml_path is not None + else None + ), + utc=loaded_model.utc, + simulation_note=simulation_note, + ) + + +def resolve_simulation_settings( + *, + loaded_model: LoadedBiomodel, + max_points: int, + fallback_start: float, + fallback_duration: float, + fallback_points: int, +) -> SimulationSettings: + """Resolve effective simulation settings with an SED-ML fallback path. + + Args: + loaded_model: Loaded model bundle with optional UTC metadata. + max_points: Global cap for sampled points. + fallback_start: Start time used when SED-ML metadata is unavailable. + fallback_duration: Duration used when SED-ML metadata is unavailable. + fallback_points: Point count used when SED-ML metadata is unavailable. + + Returns: + Effective settings to pass into ``data2sbml.py``. + """ + + if loaded_model.utc is not None: + return SimulationSettings( + start=DEFAULT_FALLBACK_START, + duration=float(loaded_model.utc.duration), + points=min(int(loaded_model.utc.number_of_points), int(max_points)), + source="sedml_uniform_time_course", + note=loaded_model.simulation_note, + utc_duration=float(loaded_model.utc.duration), + utc_points=int(loaded_model.utc.number_of_points), + ) + + return SimulationSettings( + start=float(fallback_start), + duration=float(fallback_duration), + points=min(int(fallback_points), int(max_points)), + source="fallback_defaults", + note=loaded_model.simulation_note or "missing_uniform_time_course", + ) + + def load_summary(summary_path: Path) -> dict[str, Any]: """Load a JSON run summary if it exists. @@ -296,13 +572,10 @@ def build_summary_row( *, biomodel_id: str, sbml_path: Path, - sedml_path: Path, + sedml_path: Path | None, summary_path: Path, species_ids: list[str], - utc_duration: float, - utc_points: int, - duration: float, - points: int, + settings: SimulationSettings, exit_code: int | str, ) -> dict[str, Any]: """Build one TSV row from the pipeline summary and batch metadata. @@ -310,13 +583,10 @@ def build_summary_row( Args: biomodel_id: Curated BioModels identifier. sbml_path: Cached SBML path. - sedml_path: Cached SED-ML path. + sedml_path: Cached SED-ML path when available. summary_path: ``run_summary.json`` path for the model. species_ids: Species passed to ``data2sbml.py``. - utc_duration: Duration requested by the source SED-ML. - utc_points: Output points requested by the source SED-ML. - duration: Effective duration used for the run. - points: Effective output points used for the run. + settings: Effective simulation settings used for the run. exit_code: Process exit code or blank when no process ran. Returns: @@ -330,10 +600,13 @@ def build_summary_row( "model_id": biomodel_id, "status": nested_value(summary_payload, "status") or "pipeline_error", "exit_code": exit_code, - "utc_duration": utc_duration, - "utc_points": utc_points, - "duration": duration, - "points": points, + "start": settings.start, + "utc_duration": settings.utc_duration, + "utc_points": settings.utc_points, + "duration": settings.duration, + "points": settings.points, + "simulation_settings_source": settings.source, + "simulation_settings_note": settings.note, "original_model_equation_count": ( model_metrics.get("original_model_equation_count", "") if isinstance(model_metrics, dict) @@ -366,7 +639,7 @@ def build_summary_row( "error_message": nested_value(summary_payload, "error_message"), "species_ids": ",".join(species_ids), "sbml_path": str(sbml_path), - "sedml_path": str(sedml_path), + "sedml_path": str(sedml_path) if sedml_path is not None else "", "run_summary": str(summary_path), } @@ -401,42 +674,74 @@ def main() -> None: module = load_biocompose_run_biomodels() artifacts_dir = Path(args.artifacts_dir).resolve() download_dir = Path(args.download_dir).resolve() + local_models_dir = ( + Path(args.local_models_dir).resolve() if args.local_models_dir else None + ) summary_path = Path(args.summary_path).resolve() config_path = Path(args.config).resolve() artifacts_dir.mkdir(parents=True, exist_ok=True) download_dir.mkdir(parents=True, exist_ok=True) + if local_models_dir is not None and not local_models_dir.is_dir(): + raise ValueError( + f"--local-models-dir does not exist or is not a directory: {local_models_dir}" + ) rows: list[dict[str, Any]] = [] - biomodel_ids = select_model_ids(args, module) + biomodel_ids = ( + select_local_model_ids(args, local_models_dir) + if local_models_dir is not None + else select_model_ids(args, module) + ) for index, biomodel_id in enumerate(biomodel_ids, start=1): print(f"[{index}/{len(biomodel_ids)}] Loading {biomodel_id} ...") output_dir = artifacts_dir / biomodel_id summary_json_path = output_dir / "run_summary.json" loaded_model = None try: - metadata = module.biomodels.get_metadata(biomodel_id) - loaded_model = module.load_biomodel( - biomodel_id, - metadata, - stable_root=str(download_dir), - require_sedml=not args.download_only, - require_utc=not args.download_only, + loaded_model = ( + load_local_biomodel( + biomodel_id=biomodel_id, + local_models_dir=local_models_dir, + module=module, + require_sedml=args.require_sedml, + ) + if local_models_dir is not None + else load_remote_biomodel( + biomodel_id=biomodel_id, + module=module, + download_dir=download_dir, + require_sedml=args.require_sedml, + ) ) if args.download_only: rows.append( { "model_id": biomodel_id, - "status": "downloaded", + "status": ( + "discovered_local_files" + if local_models_dir is not None + else "downloaded" + ), "exit_code": "", - "sbml_path": loaded_model.sbml_path, - "sedml_path": loaded_model.sedml_path or "", + "simulation_settings_source": ( + "sedml_uniform_time_course" + if loaded_model.utc is not None + else "fallback_defaults" + ), + "simulation_settings_note": loaded_model.simulation_note, + "sbml_path": str(loaded_model.sbml_path), + "sedml_path": ( + str(loaded_model.sedml_path) + if loaded_model.sedml_path is not None + else "" + ), "run_summary": str(summary_json_path), } ) write_summary_rows(rows, summary_path) continue species_ids = resolve_species_ids( - Path(loaded_model.sbml_path), + loaded_model.sbml_path, args.max_species, ) except Exception as exc: @@ -452,26 +757,38 @@ def main() -> None: "status": status, "exit_code": "", "error_message": error_message, - "sbml_path": loaded_model.sbml_path if loaded_model else "", - "sedml_path": loaded_model.sedml_path if loaded_model else "", + "sbml_path": str(loaded_model.sbml_path) if loaded_model else "", + "sedml_path": ( + str(loaded_model.sedml_path) + if loaded_model and loaded_model.sedml_path is not None + else "" + ), "run_summary": str(summary_json_path), } ) write_summary_rows(rows, summary_path) continue - duration = float(loaded_model.utc.duration) - points = min(int(loaded_model.utc.number_of_points), int(args.max_points)) + settings = resolve_simulation_settings( + loaded_model=loaded_model, + max_points=args.max_points, + fallback_start=args.fallback_start, + fallback_duration=args.fallback_duration, + fallback_points=args.fallback_points, + ) print( f"[{index}/{len(biomodel_ids)}] Running {biomodel_id} with " - f"{len(species_ids)} species, duration={duration}, points={points} ..." + f"{len(species_ids)} species, start={settings.start}, " + f"duration={settings.duration}, points={settings.points}, " + f"source={settings.source} ..." ) exit_code = run_pipeline( - sbml_path=Path(loaded_model.sbml_path), + sbml_path=loaded_model.sbml_path, output_dir=output_dir, species_ids=species_ids, - duration=duration, - points=points, + start=settings.start, + duration=settings.duration, + points=settings.points, config_path=config_path, niterations=args.niterations, maxsize=args.maxsize, @@ -479,15 +796,12 @@ def main() -> None: ) rows.append( build_summary_row( - biomodel_id=biomodel_id, - sbml_path=Path(loaded_model.sbml_path), - sedml_path=Path(loaded_model.sedml_path), + biomodel_id=loaded_model.biomodel_id, + sbml_path=loaded_model.sbml_path, + sedml_path=loaded_model.sedml_path, summary_path=summary_json_path, species_ids=species_ids, - utc_duration=float(loaded_model.utc.duration), - utc_points=int(loaded_model.utc.number_of_points), - duration=duration, - points=points, + settings=settings, exit_code=exit_code, ) ) diff --git a/data2sbml.py b/data2sbml.py index e6a1157..5a7385d 100644 --- a/data2sbml.py +++ b/data2sbml.py @@ -15,9 +15,12 @@ import json import os import re +import tempfile from dataclasses import asdict, dataclass +from itertools import product from pathlib import Path import tomllib +from typing import Literal import xml.etree.ElementTree as element_tree import matplotlib.pyplot as plt @@ -37,9 +40,14 @@ DEFAULT_POINTS = 801 DEFAULT_POLYORDER = 3 DEFAULT_BAD_FIT_RMSE_THRESHOLD = 0.1 +DEFAULT_SELECTION_STRATEGY = "pysr_model_selection" +DEFAULT_SELECTION_TOP_K = 3 +DEFAULT_SELECTION_MAX_MODEL_EVALS = 256 +DEFAULT_GLOBAL_COMPLEXITY_PENALTY = 1e-3 SBML_NS = "http://www.sbml.org/sbml/level2/version4" MATHML_NS = "http://www.w3.org/1998/Math/MathML" PYSR_PARSER_TRANSFORMATIONS = standard_transformations + (convert_xor,) +VALID_PYSR_MODEL_SELECTIONS = {"accuracy", "best", "score"} FORBIDDEN_STATE_FUNCTIONS = { "acosh", "acos", @@ -141,11 +149,13 @@ class ParetoCandidate: complexity: Expression complexity reported by PySR. loss: Candidate loss reported by PySR. equation: SymPy-compatible candidate equation string. + score: Optional PySR score used by ``model_selection='best'``. """ complexity: int loss: float equation: str + score: float | None = None @dataclass @@ -163,6 +173,56 @@ class ModelParetoPoint: selections: dict[str, ParetoCandidate] +@dataclass +class EvaluatedModelCandidate: + """Store one full inferred-model candidate after trajectory evaluation. + + Args: + equations: Selected equations for all inferred species. + total_complexity: Sum of equation complexities. + total_loss: Sum of species-level losses. + rmse_metrics: Trajectory RMSE metrics for the reconstructed model. + objective_value: Optional global selection objective value. + """ + + equations: list[EquationSummary] + total_complexity: int + total_loss: float + rmse_metrics: dict[str, float] + objective_value: float | None = None + + +@dataclass +class SelectedModelResult: + """Store the selected model and metadata about the selection strategy. + + Args: + equations: Final equations selected for the inferred model. + strategy: High-level model-selection strategy identifier. + local_model_selection: Species-level PySR selection rule used for ranking. + candidate_model_count: Number of full-model combinations before truncation. + shortlisted_model_count: Number of combinations kept after truncation. + simulated_model_count: Number of full models successfully simulated. + unstable_model_count: Number of shortlisted models that failed simulation. + selected_rmse_mean: Optional mean RMSE of the winning model. + selected_objective_value: Optional global objective value of the winner. + total_complexity: Sum of equation complexities for the winning model. + total_loss: Sum of species-level losses for the winning model. + """ + + equations: list[EquationSummary] + strategy: str + local_model_selection: str + candidate_model_count: int + shortlisted_model_count: int + simulated_model_count: int + unstable_model_count: int + selected_rmse_mean: float | None + selected_objective_value: float | None + total_complexity: int + total_loss: float + + def configure_julia_environment(root_dir: Path) -> None: """Set optional Julia environment variables when a local Julia exists. @@ -767,6 +827,253 @@ def select_best_candidate(candidates: list[ParetoCandidate]) -> ParetoCandidate: ) +def normalize_model_selection(model_selection: str) -> str: + """Validate and normalize a PySR model-selection identifier. + + Args: + model_selection: Raw PySR model-selection name. + + Returns: + Lowercase validated model-selection name. + """ + + normalized_selection = model_selection.strip().lower() + if normalized_selection not in VALID_PYSR_MODEL_SELECTIONS: + valid_text = ", ".join(sorted(VALID_PYSR_MODEL_SELECTIONS)) + raise ValueError( + f"Unsupported PySR model selection '{model_selection}'. " + f"Expected one of: {valid_text}" + ) + return normalized_selection + + +def compute_candidate_scores( + candidates: list[ParetoCandidate], +) -> list[ParetoCandidate]: + """Attach PySR-compatible scores to Pareto candidates. + + Args: + candidates: Candidate equations for one species. + + Returns: + Candidates sorted by complexity with a ``score`` value populated. + """ + + scored_candidates: list[ParetoCandidate] = [] + last_loss: float | None = None + last_complexity = 0 + + for candidate in sorted( + candidates, + key=lambda item: (item.complexity, item.loss, item.equation), + ): + if candidate.score is not None: + score = float(candidate.score) + elif last_loss is None: + score = 0.0 + elif candidate.loss <= 0.0: + score = float("inf") + elif last_loss <= 0.0: + score = float("-inf") + else: + complexity_delta = candidate.complexity - last_complexity + if complexity_delta <= 0: + score = 0.0 + else: + score = float(-np.log(candidate.loss / last_loss) / complexity_delta) + + scored_candidates.append( + ParetoCandidate( + complexity=candidate.complexity, + loss=candidate.loss, + equation=candidate.equation, + score=score, + ) + ) + last_loss = candidate.loss + last_complexity = candidate.complexity + + return scored_candidates + + +def rank_candidates_by_model_selection( + candidates: list[ParetoCandidate], + model_selection: str, +) -> list[ParetoCandidate]: + """Rank safe Pareto candidates using PySR's local selection semantics. + + Args: + candidates: Candidate equations for one species. + model_selection: PySR ``model_selection`` rule. + + Returns: + Ranked candidates where index 0 is the preferred local choice. + """ + + normalized_selection = normalize_model_selection(model_selection) + scored_candidates = compute_candidate_scores(candidates) + + if normalized_selection == "accuracy": + return sorted( + scored_candidates, + key=lambda candidate: (candidate.loss, candidate.complexity, candidate.equation), + ) + + if normalized_selection == "score": + return sorted( + scored_candidates, + key=lambda candidate: ( + -(candidate.score if candidate.score is not None else float("-inf")), + candidate.loss, + candidate.complexity, + candidate.equation, + ), + ) + + threshold = 1.5 * min(candidate.loss for candidate in scored_candidates) + within_threshold = [ + candidate for candidate in scored_candidates if candidate.loss <= threshold + ] + outside_threshold = [ + candidate for candidate in scored_candidates if candidate.loss > threshold + ] + within_threshold.sort( + key=lambda candidate: ( + -(candidate.score if candidate.score is not None else float("-inf")), + candidate.loss, + candidate.complexity, + candidate.equation, + ) + ) + outside_threshold.sort( + key=lambda candidate: (candidate.loss, candidate.complexity, candidate.equation) + ) + return within_threshold + outside_threshold + + +def build_equation_summary( + target_species: str, + candidate: ParetoCandidate, + species_ids: list[str], +) -> EquationSummary: + """Build a serializable equation summary from one selected candidate. + + Args: + target_species: Species whose derivative is represented. + candidate: Selected symbolic-regression candidate. + species_ids: Allowed state-variable identifiers. + + Returns: + Equation summary with complexity, loss, and node count. + """ + + parsed_expression = parse_equation_expression(candidate.equation, species_ids) + return EquationSummary( + target_species=target_species, + expression=candidate.equation, + loss=float(candidate.loss), + complexity=int(candidate.complexity), + node_count=count_expression_nodes(parsed_expression), + ) + + +def count_model_combinations( + species_ids: list[str], + species_candidates: dict[str, list[ParetoCandidate]], +) -> int: + """Count the number of full-model combinations in a candidate grid. + + Args: + species_ids: Ordered species IDs in the model. + species_candidates: Candidate equations per species. + + Returns: + Total number of full-model combinations. + """ + + combination_count = 1 + for species_id in species_ids: + combination_count *= len(species_candidates[species_id]) + return combination_count + + +def trim_candidate_grid_to_limit( + species_ids: list[str], + species_candidates: dict[str, list[ParetoCandidate]], + max_model_evals: int, +) -> dict[str, list[ParetoCandidate]]: + """Trim a species-candidate grid until it fits within an evaluation budget. + + Args: + species_ids: Ordered species IDs in the model. + species_candidates: Ranked species candidates before truncation. + max_model_evals: Maximum number of full-model simulations to attempt. + + Returns: + Possibly truncated candidate grid that satisfies the evaluation budget. + """ + + if max_model_evals < 1: + raise ValueError("selection_max_model_evals must be at least 1.") + + trimmed_candidates = { + species_id: list(species_candidates[species_id]) for species_id in species_ids + } + while count_model_combinations(species_ids, trimmed_candidates) > max_model_evals: + reducible_species = [ + species_id + for species_id in species_ids + if len(trimmed_candidates[species_id]) > 1 + ] + if not reducible_species: + break + species_to_trim = max( + reducible_species, + key=lambda species_id: ( + len(trimmed_candidates[species_id]), + species_id, + ), + ) + trimmed_candidates[species_to_trim].pop() + + if count_model_combinations(species_ids, trimmed_candidates) > max_model_evals: + raise ValueError( + "Could not reduce the model-search grid below selection_max_model_evals." + ) + return trimmed_candidates + + +def compute_trajectory_rmse_metrics( + original_df: pd.DataFrame, + inferred_df: pd.DataFrame, + species_ids: list[str], +) -> dict[str, float]: + """Compute per-species and mean RMSE between two simulated trajectories. + + Args: + original_df: Reference trajectories from the original SBML model. + inferred_df: Trajectories produced by an inferred SBML model. + species_ids: Ordered species IDs to compare. + + Returns: + RMSE dictionary with one key per species and a ``rmse_mean`` entry. + """ + + rmse_metrics: dict[str, float] = {} + for species_id in species_ids: + reference_values = original_df[species_id].to_numpy(dtype=float) + inferred_values = inferred_df[species_id].to_numpy(dtype=float) + rmse_metrics[f"rmse_{species_id}"] = compute_rmse( + reference_values, + inferred_values, + ) + + rmse_metrics["rmse_mean"] = float( + np.mean([rmse_metrics[f"rmse_{species_id}"] for species_id in species_ids]) + ) + return rmse_metrics + + def prune_species_pareto_candidates( candidates: list[ParetoCandidate], ) -> list[ParetoCandidate]: @@ -823,17 +1130,28 @@ def load_species_pareto_candidates( f"PySR hall_of_fame file is missing required columns: {missing_text}" ) + score_column = None + for candidate_column in ("score", "Score"): + if candidate_column in pareto_df.columns: + score_column = candidate_column + break + candidates: list[ParetoCandidate] = [] - for row in pareto_df.itertuples(index=False): + for row in pareto_df.to_dict(orient="records"): candidates.append( ParetoCandidate( - complexity=int(row.Complexity), - loss=float(row.Loss), + complexity=int(row["Complexity"]), + loss=float(row["Loss"]), equation=restore_raw_pysr_equation( - expression=str(row.Equation), + expression=str(row["Equation"]), safe_symbol_locals=safe_symbol_locals, original_symbol_map=original_symbol_map, ), + score=( + float(row[score_column]) + if score_column is not None and row[score_column] is not None + else None + ), ) ) @@ -1026,11 +1344,10 @@ def infer_symbolic_equations( polyorder: int, output_dir: Path, ) -> tuple[ - list[EquationSummary], dict[str, dict[str, str | None]], dict[str, list[ParetoCandidate]], ]: - """Fit one symbolic regressor per species derivative. + """Fit one symbolic regressor per species derivative and collect candidates. Args: timecourse_df: Simulated state trajectories from the original SBML model. @@ -1041,7 +1358,7 @@ def infer_symbolic_equations( output_dir: Artifact directory for the current run. Returns: - Tuple with inferred equations, species artifact metadata, and species frontiers. + Tuple with species artifact metadata and species-level frontiers. """ time_values = timecourse_df["time"].to_numpy(dtype=float) @@ -1056,7 +1373,6 @@ def infer_symbolic_equations( pysr_output_root = output_dir / "pysr" pysr_output_root.mkdir(parents=True, exist_ok=True) - equation_summaries: list[EquationSummary] = [] species_artifacts: dict[str, dict[str, str | None]] = {} species_frontiers: dict[str, list[ParetoCandidate]] = {} for species_index, target_species in enumerate(species_ids): @@ -1098,7 +1414,6 @@ def infer_symbolic_equations( f"Rejected {len(rejection_messages)} candidates.\n{sample_rejections}" ) species_candidates = prune_species_pareto_candidates(safe_species_candidates) - selected_candidate = select_best_candidate(species_candidates) species_frontiers[target_species] = species_candidates species_pareto_csv_path = species_output_dir / "pareto_frontier.csv" @@ -1124,22 +1439,306 @@ def infer_symbolic_equations( "pareto_frontier_png": str(species_pareto_plot_path), "trajectory_plot_png": None, } - parsed_expression = parse_equation_expression( - selected_candidate.equation, - species_ids, + return species_artifacts, species_frontiers + + +def evaluate_model_candidate( + metadata: SBMLMetadata, + species_ids: list[str], + original_df: pd.DataFrame, + candidate_selection: dict[str, ParetoCandidate], + start: float, + duration: float, + points: int, + work_dir: Path, + candidate_index: int, +) -> EvaluatedModelCandidate: + """Evaluate one full inferred-model candidate by reconstructing and simulating it. + + Args: + metadata: Original SBML metadata reused by the inferred model. + species_ids: Ordered species IDs in the model. + original_df: Reference trajectories from the original SBML model. + candidate_selection: Candidate equation chosen for each species. + start: Simulation start time. + duration: Simulation duration. + points: Number of sampled time points. + work_dir: Temporary directory for transient candidate SBML files. + candidate_index: Monotonic index used to derive a unique filename. + + Returns: + Evaluated full-model candidate with RMSE metrics. + """ + + equations = [ + build_equation_summary( + target_species=species_id, + candidate=candidate_selection[species_id], + species_ids=species_ids, ) + for species_id in species_ids + ] + sbml_path = work_dir / f"candidate_{candidate_index:04d}.xml" + write_inferred_sbml(metadata=metadata, equations=equations, output_path=sbml_path) - equation_summaries.append( - EquationSummary( - target_species=target_species, - expression=selected_candidate.equation, - loss=float(selected_candidate.loss), - complexity=int(selected_candidate.complexity), - node_count=count_expression_nodes(parsed_expression), - ) + rr_inferred = load_sbml_model(sbml_path) + inferred_df = simulate_timecourse( + rr=rr_inferred, + species_ids=species_ids, + start=start, + duration=duration, + points=points, + ) + rmse_metrics = compute_trajectory_rmse_metrics( + original_df=original_df, + inferred_df=inferred_df, + species_ids=species_ids, + ) + return EvaluatedModelCandidate( + equations=equations, + total_complexity=sum( + candidate_selection[species_id].complexity for species_id in species_ids + ), + total_loss=float( + sum(candidate_selection[species_id].loss for species_id in species_ids) + ), + rmse_metrics=rmse_metrics, + ) + + +def select_local_pysr_model( + species_ids: list[str], + species_frontiers: dict[str, list[ParetoCandidate]], + local_model_selection: str, +) -> SelectedModelResult: + """Select one equation per species using PySR's local model-selection rule. + + Args: + species_ids: Ordered species IDs in the model. + species_frontiers: Safe Pareto frontier per species. + local_model_selection: PySR model-selection rule. + + Returns: + Selected model result without global reranking. + """ + + equations = [ + build_equation_summary( + target_species=species_id, + candidate=rank_candidates_by_model_selection( + species_frontiers[species_id], + local_model_selection, + )[0], + species_ids=species_ids, + ) + for species_id in species_ids + ] + return SelectedModelResult( + equations=equations, + strategy="pysr_model_selection", + local_model_selection=normalize_model_selection(local_model_selection), + candidate_model_count=1, + shortlisted_model_count=1, + simulated_model_count=0, + unstable_model_count=0, + selected_rmse_mean=None, + selected_objective_value=None, + total_complexity=sum(equation.complexity for equation in equations), + total_loss=float(sum(equation.loss for equation in equations)), + ) + + +def shortlist_species_candidates( + species_ids: list[str], + species_frontiers: dict[str, list[ParetoCandidate]], + local_model_selection: str, + top_k: int | None, + max_model_evals: int, +) -> tuple[dict[str, list[ParetoCandidate]], int, int]: + """Build a bounded candidate grid for global full-model search. + + Args: + species_ids: Ordered species IDs in the model. + species_frontiers: Safe Pareto frontier per species. + local_model_selection: PySR model-selection rule used to rank each frontier. + top_k: Optional cap per species before applying the global evaluation budget. + max_model_evals: Maximum number of full-model evaluations. + + Returns: + Tuple with the candidate grid, raw combination count, and truncated count. + """ + + if top_k is not None and top_k < 1: + raise ValueError("selection_top_k must be at least 1 when provided.") + + ranked_candidates = { + species_id: rank_candidates_by_model_selection( + species_frontiers[species_id], + local_model_selection, ) + for species_id in species_ids + } + if top_k is not None: + ranked_candidates = { + species_id: candidates[:top_k] + for species_id, candidates in ranked_candidates.items() + } + + candidate_model_count = count_model_combinations(species_ids, ranked_candidates) + shortlisted_candidates = trim_candidate_grid_to_limit( + species_ids=species_ids, + species_candidates=ranked_candidates, + max_model_evals=max_model_evals, + ) + shortlisted_model_count = count_model_combinations( + species_ids, + shortlisted_candidates, + ) + return shortlisted_candidates, candidate_model_count, shortlisted_model_count - return equation_summaries, species_artifacts, species_frontiers + +def select_global_model( + strategy: Literal["global_rmse", "global_multiobjective"], + metadata: SBMLMetadata, + species_ids: list[str], + species_frontiers: dict[str, list[ParetoCandidate]], + original_df: pd.DataFrame, + local_model_selection: str, + start: float, + duration: float, + points: int, + top_k: int | None, + max_model_evals: int, + complexity_penalty: float, +) -> SelectedModelResult: + """Select a full inferred model by evaluating reconstructed SBML simulations. + + Args: + strategy: Global selection strategy identifier. + metadata: Original SBML metadata reused by the inferred model. + species_ids: Ordered species IDs in the model. + species_frontiers: Safe Pareto frontier per species. + original_df: Reference trajectories from the original SBML model. + local_model_selection: PySR model-selection rule used to rank candidates. + start: Simulation start time. + duration: Simulation duration. + points: Number of sampled time points. + top_k: Optional per-species shortlist size before truncation. + max_model_evals: Maximum number of full-model evaluations. + complexity_penalty: Penalty applied to total complexity for the multiobjective strategy. + + Returns: + Selected model result based on the requested global strategy. + """ + + if complexity_penalty < 0.0: + raise ValueError("global_complexity_penalty must be non-negative.") + + shortlisted_candidates, candidate_model_count, shortlisted_model_count = ( + shortlist_species_candidates( + species_ids=species_ids, + species_frontiers=species_frontiers, + local_model_selection=local_model_selection, + top_k=top_k, + max_model_evals=max_model_evals, + ) + ) + evaluated_candidates: list[EvaluatedModelCandidate] = [] + unstable_model_count = 0 + normalized_local_selection = normalize_model_selection(local_model_selection) + + with tempfile.TemporaryDirectory(prefix="data2sbml_model_search_") as temp_dir: + work_dir = Path(temp_dir) + candidate_lists = [shortlisted_candidates[species_id] for species_id in species_ids] + for candidate_index, candidate_tuple in enumerate(product(*candidate_lists), start=1): + candidate_selection = dict(zip(species_ids, candidate_tuple, strict=True)) + try: + evaluated_candidate = evaluate_model_candidate( + metadata=metadata, + species_ids=species_ids, + original_df=original_df, + candidate_selection=candidate_selection, + start=start, + duration=duration, + points=points, + work_dir=work_dir, + candidate_index=candidate_index, + ) + except Exception: + unstable_model_count += 1 + continue + + if strategy == "global_multiobjective": + evaluated_candidate.objective_value = ( + evaluated_candidate.rmse_metrics["rmse_mean"] + + complexity_penalty * evaluated_candidate.total_complexity + ) + evaluated_candidates.append(evaluated_candidate) + + if not evaluated_candidates: + raise ValueError( + "No shortlisted full-model candidate produced a stable inferred simulation." + ) + + if strategy == "global_rmse": + best_candidate = min( + evaluated_candidates, + key=lambda candidate: ( + candidate.rmse_metrics["rmse_mean"], + candidate.total_complexity, + candidate.total_loss, + ), + ) + else: + best_candidate = min( + evaluated_candidates, + key=lambda candidate: ( + candidate.objective_value + if candidate.objective_value is not None + else float("inf"), + candidate.rmse_metrics["rmse_mean"], + candidate.total_complexity, + candidate.total_loss, + ), + ) + + return SelectedModelResult( + equations=best_candidate.equations, + strategy=strategy, + local_model_selection=normalized_local_selection, + candidate_model_count=candidate_model_count, + shortlisted_model_count=shortlisted_model_count, + simulated_model_count=len(evaluated_candidates), + unstable_model_count=unstable_model_count, + selected_rmse_mean=best_candidate.rmse_metrics["rmse_mean"], + selected_objective_value=best_candidate.objective_value, + total_complexity=best_candidate.total_complexity, + total_loss=best_candidate.total_loss, + ) + + +def build_selection_summary(result: SelectedModelResult) -> dict[str, object]: + """Build a JSON-serializable summary of the selected-model strategy. + + Args: + result: Final selected-model result. + + Returns: + Summary payload without duplicating the equation list. + """ + + return { + "strategy": result.strategy, + "local_model_selection": result.local_model_selection, + "candidate_model_count": result.candidate_model_count, + "shortlisted_model_count": result.shortlisted_model_count, + "simulated_model_count": result.simulated_model_count, + "unstable_model_count": result.unstable_model_count, + "selected_rmse_mean": result.selected_rmse_mean, + "selected_objective_value": result.selected_objective_value, + "total_complexity": result.total_complexity, + "total_loss": result.total_loss, + } def validate_equation_symbols( @@ -1360,7 +1959,11 @@ def plot_trajectory_comparison( RMSE metrics per species plus the mean RMSE. """ - rmse_metrics: dict[str, float] = {} + rmse_metrics = compute_trajectory_rmse_metrics( + original_df=original_df, + inferred_df=inferred_df, + species_ids=species_ids, + ) time_values = original_df["time"].to_numpy(dtype=float) figure, axes = plt.subplots( @@ -1375,8 +1978,7 @@ def plot_trajectory_comparison( for axis, species_id in zip(axes, species_ids, strict=False): reference_values = original_df[species_id].to_numpy(dtype=float) inferred_values = inferred_df[species_id].to_numpy(dtype=float) - species_rmse = compute_rmse(reference_values, inferred_values) - rmse_metrics[f"rmse_{species_id}"] = species_rmse + species_rmse = rmse_metrics[f"rmse_{species_id}"] axis.plot(time_values, reference_values, label="Original SBML", linewidth=2.0) axis.plot( @@ -1397,9 +1999,6 @@ def plot_trajectory_comparison( figure.savefig(output_path, dpi=150) plt.close(figure) - rmse_metrics["rmse_mean"] = float( - np.mean([rmse_metrics[f"rmse_{species_id}"] for species_id in species_ids]) - ) return rmse_metrics @@ -1494,6 +2093,7 @@ def build_summary_payload( metadata: SBMLMetadata | None, artifacts: dict[str, str], status: str, + selection: dict[str, object] | None = None, pysr_kwargs: dict[str, object] | None = None, species_artifacts: dict[str, dict[str, str | None]] | None = None, metrics: dict[str, float] | None = None, @@ -1511,6 +2111,7 @@ def build_summary_payload( metadata: Optional SBML metadata from the source model. artifacts: Artifact paths that exist or are expected for the run. status: High-level run status. + selection: Optional model-selection summary. pysr_kwargs: Optional PySR settings used by the run. species_artifacts: Optional nested artifact metadata per inferred species. metrics: Optional metric dictionary for completed validations. @@ -1533,6 +2134,8 @@ def build_summary_payload( } if metadata is not None: payload["metadata"] = asdict(metadata) + if selection is not None: + payload["selection"] = selection if pysr_kwargs is not None: payload["pysr_kwargs"] = pysr_kwargs if species_artifacts is not None: @@ -1639,6 +2242,46 @@ def parse_args() -> argparse.Namespace: type=int, help="Optional override for the PySR maxsize setting.", ) + parser.add_argument( + "--selection-strategy", + default=DEFAULT_SELECTION_STRATEGY, + choices=[ + "pysr_model_selection", + "global_rmse", + "global_multiobjective", + ], + help=( + "Model-selection strategy ordered from cheapest to most expensive: " + "local PySR selection only, global RMSE reranking, or global " + "RMSE-plus-complexity reranking." + ), + ) + parser.add_argument( + "--selection-top-k", + type=int, + default=None, + help=( + "Per-species shortlist size used by global selection strategies. " + f"Defaults to {DEFAULT_SELECTION_TOP_K} for global_rmse and to the " + "full safe frontier for global_multiobjective. Ignored by the local " + "PySR-only strategy." + ), + ) + parser.add_argument( + "--selection-max-model-evals", + type=int, + default=DEFAULT_SELECTION_MAX_MODEL_EVALS, + help="Maximum number of full-model combinations evaluated by global search.", + ) + parser.add_argument( + "--global-complexity-penalty", + type=float, + default=DEFAULT_GLOBAL_COMPLEXITY_PENALTY, + help=( + "Penalty multiplied by total complexity when using the " + "global_multiobjective selection strategy." + ), + ) parser.add_argument( "--bad-fit-rmse-threshold", type=float, @@ -1666,6 +2309,7 @@ def main() -> None: metadata: SBMLMetadata | None = None species_ids: list[str] = [] pysr_kwargs: dict[str, object] | None = None + selection_summary: dict[str, object] | None = None species_artifacts: dict[str, dict[str, str | None]] = {} model_metrics: dict[str, int] | None = None @@ -1690,8 +2334,11 @@ def main() -> None: if args.maxsize is not None: pysr_config["maxsize"] = args.maxsize pysr_kwargs = build_pysr_kwargs(pysr_config) + local_model_selection = normalize_model_selection( + str(pysr_kwargs.get("model_selection", "best")) + ) - equations, species_artifacts, species_frontiers = infer_symbolic_equations( + species_artifacts, species_frontiers = infer_symbolic_equations( timecourse_df=original_df, species_ids=species_ids, pysr_kwargs=pysr_kwargs, @@ -1711,6 +2358,50 @@ def main() -> None: ) artifacts["model_pareto_front_csv"] = str(model_pareto_csv_path) artifacts["model_pareto_front_png"] = str(model_pareto_plot_path) + + if args.selection_strategy == "pysr_model_selection": + selection_result = select_local_pysr_model( + species_ids=species_ids, + species_frontiers=species_frontiers, + local_model_selection=local_model_selection, + ) + elif args.selection_strategy == "global_rmse": + selection_result = select_global_model( + strategy="global_rmse", + metadata=metadata, + species_ids=species_ids, + species_frontiers=species_frontiers, + original_df=original_df, + local_model_selection=local_model_selection, + start=args.start, + duration=args.duration, + points=args.points, + top_k=( + args.selection_top_k + if args.selection_top_k is not None + else DEFAULT_SELECTION_TOP_K + ), + max_model_evals=args.selection_max_model_evals, + complexity_penalty=args.global_complexity_penalty, + ) + else: + selection_result = select_global_model( + strategy="global_multiobjective", + metadata=metadata, + species_ids=species_ids, + species_frontiers=species_frontiers, + original_df=original_df, + local_model_selection=local_model_selection, + start=args.start, + duration=args.duration, + points=args.points, + top_k=args.selection_top_k, + max_model_evals=args.selection_max_model_evals, + complexity_penalty=args.global_complexity_penalty, + ) + + equations = selection_result.equations + selection_summary = build_selection_summary(selection_result) model_metrics = build_model_metrics( sbml_path=sbml_path, species_ids=species_ids, @@ -1721,6 +2412,7 @@ def main() -> None: "input_sbml": str(sbml_path), "species_ids": species_ids, "pysr_kwargs": pysr_kwargs, + "selection": selection_summary, "model_metrics": model_metrics, "equations": [asdict(equation) for equation in equations], } @@ -1779,6 +2471,7 @@ def main() -> None: metadata=metadata, artifacts=artifacts, status=status, + selection=selection_summary, pysr_kwargs=pysr_kwargs, species_artifacts=species_artifacts, metrics=rmse_metrics, @@ -1794,6 +2487,7 @@ def main() -> None: metadata=metadata, artifacts=artifacts, status="unstable", + selection=selection_summary, pysr_kwargs=pysr_kwargs, species_artifacts=species_artifacts, model_metrics=model_metrics, @@ -1810,6 +2504,10 @@ def main() -> None: print(f"Input SBML: {sbml_path}") print(f"Output directory: {output_dir}") print(f"Species inferred: {', '.join(species_ids)}") + print( + f"Selection strategy: {selection_result.strategy} " + f"(local={selection_result.local_model_selection})" + ) for equation in equations: print( f"d{equation.target_species}/dt = {equation.expression} " @@ -1817,6 +2515,13 @@ def main() -> None: ) print(f"Inferred SBML: {inferred_sbml_path}") print(f"Model Pareto front: {model_pareto_plot_path}") + if selection_result.simulated_model_count > 0: + print( + f"Model search: {selection_result.simulated_model_count}/" + f"{selection_result.shortlisted_model_count} stable simulations" + ) + if selection_result.selected_objective_value is not None: + print(f"Selection objective: {selection_result.selected_objective_value:.6g}") print(f"Run status: {summary_payload['status']}") print(f"Comparison plot: {comparison_plot_path}") print(f"Mean RMSE: {rmse_metrics['rmse_mean']:.6g}") diff --git a/tests/test_biomodels_batch.py b/tests/test_biomodels_batch.py index 3f82b6d..9962a32 100644 --- a/tests/test_biomodels_batch.py +++ b/tests/test_biomodels_batch.py @@ -1,4 +1,4 @@ -"""Regression tests for BioModels batch identifier selection.""" +"""Regression tests for BioModels batch identifier selection and fallbacks.""" from types import SimpleNamespace @@ -53,3 +53,86 @@ def test_select_model_ids_preserves_explicit_model_list() -> None: assert model_ids == ["BIOMD1", "BIOMD2"] assert module.calls == [] + + +class DummyLocalSedmlModule: + """Minimal helper module for local directory loading tests.""" + + @staticmethod + def find_first_sedml(entry_files: list[object]) -> object | None: + for entry in entry_files: + if str(entry).endswith(".sedml"): + return entry + return None + + @staticmethod + def find_first_sbml(entry_files: list[object]) -> object | None: + for entry in entry_files: + if str(entry).endswith(".xml"): + return entry + return None + + +def test_select_local_model_ids_uses_directory_listing(tmp_path) -> None: + """Local directory mode should support filtering, offsetting, and limiting.""" + + for model_id in ("BIOMD0000000002", "MODEL0001", "BIOMD0000000001"): + (tmp_path / model_id).mkdir() + + args = SimpleNamespace( + model_id=None, + limit=1, + offset=1, + curated_only=True, + ) + + model_ids = biomodels_batch.select_local_model_ids(args, tmp_path) + + assert model_ids == ["BIOMD0000000002"] + + +def test_load_local_biomodel_accepts_missing_sedml_when_not_required(tmp_path) -> None: + """Models with only SBML should still load in local mode.""" + + model_dir = tmp_path / "BIOMD0000001092" + model_dir.mkdir() + xml_path = model_dir / "Gautam2020_RQ7_iJG408.xml" + xml_path.write_text("", encoding="utf-8") + + loaded_model = biomodels_batch.load_local_biomodel( + biomodel_id="BIOMD0000001092", + local_models_dir=tmp_path, + module=DummyLocalSedmlModule(), + require_sedml=False, + ) + + assert loaded_model.sbml_path == xml_path.resolve() + assert loaded_model.sedml_path is None + assert loaded_model.utc is None + assert loaded_model.simulation_note == "missing_sedml" + + +def test_resolve_simulation_settings_uses_fallback_without_utc() -> None: + """Missing UTC metadata should fall back to configured simulation defaults.""" + + loaded_model = biomodels_batch.LoadedBiomodel( + biomodel_id="BIOMD0000001092", + sbml_path=biomodels_batch.DEFAULT_LOCAL_MODELS_DIR / "dummy.xml", + sedml_path=None, + utc=None, + simulation_note="missing_sedml", + ) + + settings = biomodels_batch.resolve_simulation_settings( + loaded_model=loaded_model, + max_points=500, + fallback_start=2.5, + fallback_duration=60.0, + fallback_points=900, + ) + + assert settings.start == 2.5 + assert settings.duration == 60.0 + assert settings.points == 500 + assert settings.source == "fallback_defaults" + assert settings.note == "missing_sedml" diff --git a/tests/test_model_selection.py b/tests/test_model_selection.py new file mode 100644 index 0000000..7c2f41b --- /dev/null +++ b/tests/test_model_selection.py @@ -0,0 +1,218 @@ +"""Regression tests for local and global model-selection strategies.""" + +import pytest + +import data2sbml + + +def build_test_metadata() -> data2sbml.SBMLMetadata: + """Create minimal SBML metadata for model-selection tests.""" + + return data2sbml.SBMLMetadata( + model_id="demo", + model_name="demo", + compartments=[ + data2sbml.CompartmentDefinition(compartment_id="cell", size=1.0), + ], + species=[ + data2sbml.SpeciesDefinition( + species_id="x", + compartment_id="cell", + initial_concentration=1.0, + ), + data2sbml.SpeciesDefinition( + species_id="y", + compartment_id="cell", + initial_concentration=1.0, + ), + ], + ) + + +def test_rank_candidates_by_model_selection_uses_pysr_best_rule() -> None: + """PySR's ``best`` rule should prefer the highest score within the loss threshold.""" + + candidates = [ + data2sbml.ParetoCandidate(complexity=1, loss=4.0, equation="0.0"), + data2sbml.ParetoCandidate(complexity=2, loss=2.0, equation="x + y"), + data2sbml.ParetoCandidate(complexity=5, loss=1.4, equation="x*y + 1"), + ] + + ranked_candidates = data2sbml.rank_candidates_by_model_selection( + candidates, + "best", + ) + + assert ranked_candidates[0].equation == "x + y" + + +def test_shortlist_species_candidates_trims_to_model_budget() -> None: + """Global search shortlists should shrink the Cartesian product to the budget.""" + + species_frontiers = { + "x": [ + data2sbml.ParetoCandidate(complexity=1, loss=4.0, equation="x"), + data2sbml.ParetoCandidate(complexity=2, loss=2.0, equation="x + 1"), + data2sbml.ParetoCandidate(complexity=3, loss=1.5, equation="x + y"), + ], + "y": [ + data2sbml.ParetoCandidate(complexity=1, loss=5.0, equation="y"), + data2sbml.ParetoCandidate(complexity=2, loss=3.0, equation="y + 1"), + data2sbml.ParetoCandidate(complexity=3, loss=2.0, equation="x + y"), + ], + } + + shortlisted_candidates, candidate_model_count, shortlisted_model_count = ( + data2sbml.shortlist_species_candidates( + species_ids=["x", "y"], + species_frontiers=species_frontiers, + local_model_selection="best", + top_k=None, + max_model_evals=4, + ) + ) + + assert candidate_model_count == 9 + assert shortlisted_model_count == 4 + assert len(shortlisted_candidates["x"]) == 2 + assert len(shortlisted_candidates["y"]) == 2 + + +def test_select_global_model_prefers_lowest_rmse_combination( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The global RMSE strategy should choose the full-model combination with the best RMSE.""" + + metadata = build_test_metadata() + species_frontiers = { + "x": [ + data2sbml.ParetoCandidate(complexity=1, loss=1.0, equation="x"), + data2sbml.ParetoCandidate(complexity=2, loss=0.9, equation="x + 1"), + ], + "y": [ + data2sbml.ParetoCandidate(complexity=1, loss=1.0, equation="y"), + data2sbml.ParetoCandidate(complexity=2, loss=0.9, equation="y + 1"), + ], + } + rmse_lookup = { + ("x", "y"): 0.40, + ("x", "y + 1"): 0.30, + ("x + 1", "y"): 0.08, + ("x + 1", "y + 1"): 0.12, + } + + def fake_evaluate_model_candidate(**kwargs: object) -> data2sbml.EvaluatedModelCandidate: + candidate_selection = kwargs["candidate_selection"] + assert isinstance(candidate_selection, dict) + selected_key = ( + candidate_selection["x"].equation, + candidate_selection["y"].equation, + ) + equations = [ + data2sbml.build_equation_summary( + target_species=species_id, + candidate=candidate_selection[species_id], + species_ids=["x", "y"], + ) + for species_id in ["x", "y"] + ] + return data2sbml.EvaluatedModelCandidate( + equations=equations, + total_complexity=sum(equation.complexity for equation in equations), + total_loss=float(sum(equation.loss for equation in equations)), + rmse_metrics={ + "rmse_x": rmse_lookup[selected_key], + "rmse_y": rmse_lookup[selected_key], + "rmse_mean": rmse_lookup[selected_key], + }, + ) + + monkeypatch.setattr(data2sbml, "evaluate_model_candidate", fake_evaluate_model_candidate) + + result = data2sbml.select_global_model( + strategy="global_rmse", + metadata=metadata, + species_ids=["x", "y"], + species_frontiers=species_frontiers, + original_df=data2sbml.pd.DataFrame({"time": [0.0], "x": [1.0], "y": [1.0]}), + local_model_selection="best", + start=0.0, + duration=1.0, + points=2, + top_k=2, + max_model_evals=4, + complexity_penalty=0.0, + ) + + assert [equation.expression for equation in result.equations] == ["x + 1", "y"] + assert result.selected_rmse_mean == pytest.approx(0.08) + assert result.simulated_model_count == 4 + + +def test_select_global_model_multiobjective_penalizes_complexity( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The multiobjective strategy should trade a small RMSE gain for much lower complexity.""" + + metadata = data2sbml.SBMLMetadata( + model_id="demo", + model_name="demo", + compartments=[ + data2sbml.CompartmentDefinition(compartment_id="cell", size=1.0), + ], + species=[ + data2sbml.SpeciesDefinition( + species_id="x", + compartment_id="cell", + initial_concentration=1.0, + ), + ], + ) + species_frontiers = { + "x": [ + data2sbml.ParetoCandidate(complexity=2, loss=1.0, equation="x"), + data2sbml.ParetoCandidate(complexity=20, loss=0.8, equation="x + 1"), + ], + } + rmse_lookup = {"x": 0.21, "x + 1": 0.20} + + def fake_evaluate_model_candidate(**kwargs: object) -> data2sbml.EvaluatedModelCandidate: + candidate_selection = kwargs["candidate_selection"] + assert isinstance(candidate_selection, dict) + equation = candidate_selection["x"].equation + equations = [ + data2sbml.build_equation_summary( + target_species="x", + candidate=candidate_selection["x"], + species_ids=["x"], + ) + ] + return data2sbml.EvaluatedModelCandidate( + equations=equations, + total_complexity=equations[0].complexity, + total_loss=equations[0].loss, + rmse_metrics={ + "rmse_x": rmse_lookup[equation], + "rmse_mean": rmse_lookup[equation], + }, + ) + + monkeypatch.setattr(data2sbml, "evaluate_model_candidate", fake_evaluate_model_candidate) + + result = data2sbml.select_global_model( + strategy="global_multiobjective", + metadata=metadata, + species_ids=["x"], + species_frontiers=species_frontiers, + original_df=data2sbml.pd.DataFrame({"time": [0.0], "x": [1.0]}), + local_model_selection="best", + start=0.0, + duration=1.0, + points=2, + top_k=None, + max_model_evals=4, + complexity_penalty=0.01, + ) + + assert [equation.expression for equation in result.equations] == ["x"] + assert result.selected_objective_value == pytest.approx(0.23)