Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
*.egg-info/
__pycache__/
.pytest_cache/
.coverage
.DS_Store
.mypy_cache/
.pytest_cache/
.venv/
.vscode/
dist/
*.py[cod]
.DS_Store
*.egg
.coverage
__pycache__/
*.bak
*.egg
*.egg-info/
*.py[cod]
build/
dist/
6 changes: 0 additions & 6 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
"""Top-level Streamlit entrypoint.

This shim makes the app runnable from the repository root without requiring
manual PYTHONPATH tweaks.
"""

import sys
from pathlib import Path

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ title = "EpiCON Cost Calculator"
mount_dirs = ["src"]
text_suffixes = [".py", ".yaml", ".yml", ".css", ".js", ".html"]
css_url = "https://cdn.jsdelivr.net/npm/@stlite/browser@0.85.1/build/stlite.css"
js_url = "https://cdn.jsdelivr.net/npm/@stlite/browser@0.85.1/build/stlite.js"
js_url = "https://cdn.jsdelivr.net/npm/@stlite/browser@0.85.1/build/stlite.js"
197 changes: 166 additions & 31 deletions src/epicc/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import hashlib
import importlib.resources
import re
from typing import Any

import streamlit as st
from pydantic import BaseModel, ValidationError

from epicc.config import CONFIG
from epicc.formats import VALID_PARAMETER_SUFFIXES
Expand All @@ -14,6 +17,7 @@
from epicc.utils.model_loader import get_built_in_models
from epicc.utils.parameter_loader import load_model_params
from epicc.utils.parameter_ui import (
item_level,
render_parameters_with_indent,
reset_parameters_to_defaults,
)
Expand Down Expand Up @@ -50,11 +54,17 @@ def _render_excel_parameter_inputs(
st.sidebar.info("Upload an Excel model file to edit parameters.")
return params, label_overrides

uploaded_excel_name = uploaded_excel_model.name
if st.session_state.get("excel_active_name") != uploaded_excel_name:
st.session_state.excel_active_name = uploaded_excel_name
upload_bytes = uploaded_excel_model.getvalue()
upload_hash = hashlib.sha1(upload_bytes).hexdigest()
excel_identity = (uploaded_excel_model.name, len(upload_bytes), upload_hash)
Comment on lines +57 to +59
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uploaded_excel_model.getvalue() reads the entire XLSX into memory on every Streamlit rerun (even when the upload hasn’t changed) just to compute excel_identity. For large spreadsheets this can noticeably slow the UI. Consider using cheaper identity signals first (e.g., filename + size/last_modified if available) and only hashing when those change, or caching the computed hash/bytes in st.session_state so it’s computed once per upload.

Suggested change
upload_bytes = uploaded_excel_model.getvalue()
upload_hash = hashlib.sha1(upload_bytes).hexdigest()
excel_identity = (uploaded_excel_model.name, len(upload_bytes), upload_hash)
upload_marker = (
uploaded_excel_model.name,
getattr(uploaded_excel_model, "size", None),
getattr(uploaded_excel_model, "file_id", None),
)
cached_upload_marker = st.session_state.get("excel_upload_marker")
cached_upload_hash = st.session_state.get("excel_upload_hash")
if cached_upload_marker != upload_marker or cached_upload_hash is None:
upload_bytes = uploaded_excel_model.getvalue()
cached_upload_hash = hashlib.sha1(upload_bytes).hexdigest()
st.session_state.excel_upload_marker = upload_marker
st.session_state.excel_upload_hash = cached_upload_hash
excel_identity = (*upload_marker, cached_upload_hash)

Copilot uses AI. Check for mistakes.
should_refresh_params = False
if st.session_state.get("excel_active_identity") != excel_identity:
st.session_state.excel_active_identity = excel_identity
st.session_state.params = {}
params = st.session_state.params
should_refresh_params = True

uploaded_excel_name = uploaded_excel_model.name

editable_defaults, _ = load_excel_params_defaults_with_computed(
uploaded_excel_model, sheet_name=None, start_row=3
Expand All @@ -66,6 +76,9 @@ def handle_reset_excel() -> None:
for col_letter, default_text in current_headers.items():
st.session_state[f"label_override_{col_letter}"] = default_text

if should_refresh_params:
handle_reset_excel()

st.sidebar.button("Reset Parameters", on_click=handle_reset_excel)

if current_headers:
Expand Down Expand Up @@ -96,34 +109,53 @@ def _render_python_parameter_inputs(
model: BaseSimulationModel,
model_key: str,
params: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, str]]:
) -> tuple[dict[str, Any], dict[str, str], dict[str, Any], bool]:
label_overrides: dict[str, str] = {}

sorted_suffixes = sorted(VALID_PARAMETER_SUFFIXES)
uploaded_params = st.sidebar.file_uploader(
"Optional parameter override file",
"Optional parameter file",
type=sorted_suffixes,
help="If omitted, model defaults are used.",
accept_multiple_files=False,
)

param_identity = (
"upload" if uploaded_params else "default",
uploaded_params.name if uploaded_params else None,
)
if uploaded_params:
upload_bytes = uploaded_params.getvalue()
upload_hash = hashlib.sha1(upload_bytes).hexdigest()
param_identity = (
"upload",
uploaded_params.name,
len(upload_bytes),
upload_hash,
)
else:
Comment on lines +124 to +132
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uploaded_params.getvalue() reads the full parameter file to compute a SHA1 on every rerun. Even though these files are likely small, this pattern scales poorly and is inconsistent with Streamlit’s rerun model. Consider caching the computed digest in st.session_state (or hashing only when a cheap identity like filename+size changes) to avoid repeated full reads.

Suggested change
upload_bytes = uploaded_params.getvalue()
upload_hash = hashlib.sha1(upload_bytes).hexdigest()
param_identity = (
"upload",
uploaded_params.name,
len(upload_bytes),
upload_hash,
)
else:
upload_size = getattr(uploaded_params, "size", None)
upload_cache_identity = ("upload", uploaded_params.name, upload_size)
cached_upload_identity = st.session_state.get("uploaded_param_hash_identity")
cached_upload_hash = st.session_state.get("uploaded_param_hash")
if cached_upload_identity != upload_cache_identity or cached_upload_hash is None:
upload_bytes = uploaded_params.getvalue()
upload_size = len(upload_bytes)
upload_cache_identity = ("upload", uploaded_params.name, upload_size)
cached_upload_hash = hashlib.sha1(upload_bytes).hexdigest()
st.session_state.uploaded_param_hash_identity = upload_cache_identity
st.session_state.uploaded_param_hash = cached_upload_hash
param_identity = (
"upload",
uploaded_params.name,
upload_size,
cached_upload_hash,
)
else:
st.session_state.pop("uploaded_param_hash_identity", None)
st.session_state.pop("uploaded_param_hash", None)

Copilot uses AI. Check for mistakes.
param_identity = ("default", None, 0, None)
should_refresh_params = False
if st.session_state.get("active_param_identity") != param_identity:
st.session_state.active_param_identity = param_identity
st.session_state.params = {}
params = st.session_state.params
should_refresh_params = True

model_defaults = load_model_params(
model,
uploaded_params=uploaded_params or None,
uploaded_name=uploaded_params.name if uploaded_params else None,
)
try:
model_defaults = load_model_params(
model,
uploaded_params=uploaded_params or None,
uploaded_name=uploaded_params.name if uploaded_params else None,
)
except ValidationError as exc:
_render_validation_error_details(model.human_name(), exc, sidebar=True)
return params, label_overrides, {}, True
except ValueError as exc:
st.sidebar.error(
f"Could not read parameter file for {model.human_name()}: {exc}"
)
return params, label_overrides, {}, True

if not model_defaults:
st.sidebar.info("No default parameters defined for this model.")
return params, label_overrides
return params, label_overrides, {}, True

current_headers = model.scenario_labels

Expand All @@ -135,6 +167,9 @@ def handle_reset_python() -> None:
for key, default_text in current_headers.items():
st.session_state[f"py_label_{model_key}_{key}"] = default_text

if should_refresh_params:
handle_reset_python()

st.sidebar.button("Reset Parameters", on_click=handle_reset_python)

if current_headers:
Expand All @@ -156,7 +191,92 @@ def handle_reset_python() -> None:
)

render_parameters_with_indent(model_defaults, params, model_id=model_key)
return params, label_overrides
return params, label_overrides, model_defaults, False


def _unflatten_indented_params(flat_params: dict[str, Any]) -> dict[str, Any]:
    """Rebuild a nested dict from flat keys whose indentation encodes depth.

    Each key's depth comes from ``item_level``; the stripped key text is the
    node label. A ``None`` value marks a section header that opens a new
    nested dict, while any other value is stored as a leaf in the section
    that is currently open at that depth.
    """
    result: dict[str, Any] = {}
    open_scopes: list[dict[str, Any]] = [result]

    for indented_key, value in flat_params.items():
        depth = item_level(indented_key)
        name = indented_key.strip()

        # Close any scopes nested deeper than this entry's depth.
        del open_scopes[depth + 1 :]

        container = open_scopes[-1]
        if value is None:
            # Section header: start a new nested scope under the current one.
            container[name] = {}
            open_scopes.append(container[name])
        else:
            container[name] = value

    return result


def _merge_sidebar_values(
nested_defaults: dict[str, Any], params: dict[str, Any]
) -> dict[str, Any]:
merged: dict[str, Any] = {}
for key, value in nested_defaults.items():
if isinstance(value, dict):
merged[key] = _merge_sidebar_values(value, params)
continue

merged[key] = params.get(key, value)

return merged


def _build_typed_params(
    model: BaseSimulationModel,
    model_defaults_flat: dict[str, Any],
    params: dict[str, Any],
) -> BaseModel:
    """Validate merged sidebar parameters against the model's Pydantic schema.

    The flat, indentation-keyed defaults are first nested, then overlaid with
    any sidebar-edited leaf values, and the result is validated with
    ``model.parameter_model()``. Raises ``pydantic.ValidationError`` when the
    payload does not satisfy the schema.
    """
    structured_defaults = _unflatten_indented_params(model_defaults_flat)
    merged_payload = _merge_sidebar_values(structured_defaults, params)
    schema = model.parameter_model()
    return schema.model_validate(merged_payload)


def _render_validation_error_details(
    model_name: str, exc: ValidationError, sidebar: bool
) -> None:
    """Render a summary of a Pydantic validation failure plus full details.

    Shows an error banner with the issue count, a short preview of the first
    few issues inside an expander, and the full JSON error payload both as a
    copyable text area and as a download.

    Args:
        model_name: Human-readable model name used in messages and widget keys.
        exc: The ``ValidationError`` raised while validating parameters.
        sidebar: When True render into the sidebar, otherwise the main pane.
    """
    target = st.sidebar if sidebar else st
    issues = exc.errors()
    issue_count = len(issues)
    # Pluralize correctly: "1 issue" vs "2 issues".
    noun = "issue" if issue_count == 1 else "issues"
    target.error(f"Parameters do not match {model_name} schema ({issue_count} {noun}).")

    details = target.expander("Validation details", expanded=False)
    with details:
        preview_count = 8  # keep the inline list short; full JSON is available below
        for issue in issues[:preview_count]:
            loc_parts = issue.get("loc", [])
            path = " > ".join(str(p) for p in loc_parts) if loc_parts else "(root)"
            msg = issue.get("msg", "Invalid value")
            st.write(f"- {path}: {msg}")

        if issue_count > preview_count:
            st.caption(f"...and {issue_count - preview_count} more.")

        # Key widgets by model name, placement, and a digest of the error
        # content so reruns with different errors don't collide on widget state.
        safe_model_name = re.sub(r"[^a-z0-9]+", "_", model_name.lower()).strip("_")
        full_details = exc.json(indent=2)
        detail_digest = hashlib.sha1(full_details.encode("utf-8")).hexdigest()[:10]
        st.text_area(
            "Full details (copyable)",
            value=full_details,
            height=180,
            key=f"{safe_model_name}_{'sidebar' if sidebar else 'main'}_validation_text_{detail_digest}",
        )
        st.download_button(
            "Download full error details",
            data=full_details,
            file_name=f"{safe_model_name}_validation_error.json",
            mime="application/json",
            key=f"{safe_model_name}_{'sidebar' if sidebar else 'main'}_validation_download_{detail_digest}",
        )


def _run_excel_simulation(
Expand All @@ -183,25 +303,18 @@ def _run_excel_simulation(
def _run_python_simulation(
selected_label: str,
model: BaseSimulationModel,
params: dict[str, Any],
typed_params: BaseModel,
label_overrides: dict[str, str],
) -> None:
with st.spinner(f"Running {selected_label}..."):
st.title(model.model_title or CONFIG.app.title)
st.write(model.model_description or CONFIG.app.description)
results = model.run(params, label_overrides=label_overrides)
results = model.run(typed_params, label_overrides=label_overrides)
render_sections(model.build_sections(results))


st.set_page_config(
page_title="EpiCON Cost Calculator",
layout="wide",
initial_sidebar_state="expanded",
)

_load_styles()

st.sidebar.title("EpiCON Cost Calculator")
st.sidebar.header("Simulation Controls")

built_in_models = get_built_in_models()
Expand All @@ -218,22 +331,44 @@ def _run_python_simulation(

st.sidebar.subheader("Input Parameters")

has_input_errors = False
typed_params: BaseModel | None = None

if is_excel_model:
params, label_overrides = _render_excel_parameter_inputs(params)
model_defaults_flat: dict[str, Any] = {}
else:
params, label_overrides = _render_python_parameter_inputs(
model_registry[selected_label],
model_key,
params,
params, label_overrides, model_defaults_flat, has_input_errors = (
_render_python_parameter_inputs(
model_registry[selected_label],
model_key,
params,
)
)

if not st.sidebar.button("Run Simulation"):
if not has_input_errors:
try:
typed_params = _build_typed_params(
model_registry[selected_label], model_defaults_flat, params
)
except ValidationError as exc:
_render_validation_error_details(selected_label, exc, sidebar=True)
has_input_errors = True

if not st.sidebar.button("Run Simulation", disabled=has_input_errors):
st.stop()

if is_excel_model:
_run_excel_simulation(params, label_overrides)
st.stop()

if typed_params is None:
st.error("Cannot run simulation until parameter validation errors are fixed.")
st.stop()

_run_python_simulation(
selected_label, model_registry[selected_label], params, label_overrides
selected_label,
model_registry[selected_label],
typed_params,
label_overrides,
)
6 changes: 3 additions & 3 deletions src/epicc/formats/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ def write_template(self, model: BaseModel) -> bytes:
return self.write(model.model_dump())


__all__ = ["YAMLFormat"]


def _merge_mapping(target: CommentedMap, updates: dict[str, Any]) -> None:
    """Recursively fold plain ``updates`` into a ``CommentedMap`` template.

    A nested dict is merged in place when the template already holds a
    ``CommentedMap`` at the same key (preserving its comments); every other
    value simply overwrites the template entry. Mutates ``target``.
    """
    for key, new_value in updates.items():
        existing = target.get(key)
        if isinstance(new_value, dict) and isinstance(existing, CommentedMap):
            _merge_mapping(existing, new_value)
        else:
            target[key] = new_value


__all__ = ["YAMLFormat"]
8 changes: 7 additions & 1 deletion src/epicc/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel


class BaseSimulationModel(ABC):
"""Abstract contract for Python-defined simulation models."""
Expand All @@ -29,7 +31,7 @@ def scenario_labels(self) -> dict[str, str]:
@abstractmethod
def run(
self,
params: dict[str, Any],
params: BaseModel,
label_overrides: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Run the model and return result payload for rendering."""
Expand All @@ -38,6 +40,10 @@ def run(
def default_params(self) -> dict[str, Any]:
"""Return the model's default parameters as a raw (unflattened) dict."""

@abstractmethod
def parameter_model(self) -> type[BaseModel]:
"""Return a Pydantic model used to validate uploaded parameter files."""
Comment on lines 31 to +45
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BaseSimulationModel.run() now accepts params: BaseModel, but concrete implementations narrow the type (e.g., params: TBIsolationParams / MeaslesOutbreakParams). This is not type-safe for overrides and will typically fail mypy (parameter types are contravariant in method overrides). Consider making BaseSimulationModel generic over a ParamsT = TypeVar(bound=BaseModel) and typing run(self, params: ParamsT, ...) / parameter_model(self) -> type[ParamsT], or keep the base signature as BaseModel and cast inside implementations.

Copilot uses AI. Check for mistakes.

@abstractmethod
def build_sections(self, results: dict[str, Any]) -> list[dict[str, Any]]:
"""Transform run results into section payloads for UI rendering."""
Loading
Loading