Conversation
Introduce Pydantic Models for Builtin Models
Pull request overview
This PR refactors parameter handling for Python-based models to use Pydantic schemas for validation, improves parameter-file reload behavior in the Streamlit UI, and updates the built-in models to consume typed parameter objects.
Changes:
- Introduces model-specific Pydantic parameter schemas and validates uploaded parameter files against them.
- Updates the Streamlit app flow to (a) refresh parameters when uploaded content changes and (b) block execution on validation errors with user-facing details.
- Adjusts supporting utilities (parameter loader/UI helpers) to support typed validation + nested parameter structures.
Reviewed changes
Copilot reviewed 8 out of 11 changed files in this pull request and generated 4 comments.
Summary per file:
| File | Description |
|---|---|
| src/epicc/utils/parameter_ui.py | Exposes item_level() for shared indent parsing and updates callers accordingly. |
| src/epicc/utils/parameter_loader.py | Loads parameter files via format readers and validates them using model.parameter_model(). |
| src/epicc/models/tb_isolation.py | Adds Pydantic schemas for parameters and updates run() to use typed access. |
| src/epicc/models/measles_outbreak.py | Adds a Pydantic schema for parameters and updates run() to use typed access. |
| src/epicc/model/base.py | Changes the base model contract to accept BaseModel params and requires parameter_model(). |
| src/epicc/formats/yaml.py | Minor module export (__all__) placement adjustment. |
| src/epicc/main.py | Adds parameter validation UX, typed param construction, and more robust upload identity tracking. |
| pyproject.toml | Formatting-only change. |
| app.py | Removes the top-level docstring from the shim entrypoint. |
| .gitignore | Reorders/expands ignored artifacts (coverage, build outputs, caches). |
```diff
@@ -38,6 +40,10 @@ def run(
     def default_params(self) -> dict[str, Any]:
         """Return the model's default parameters as a raw (unflattened) dict."""

+    @abstractmethod
+    def parameter_model(self) -> type[BaseModel]:
+        """Return a Pydantic model used to validate uploaded parameter files."""
```
BaseSimulationModel.run() now accepts params: BaseModel, but concrete implementations narrow the type (e.g., params: TBIsolationParams / MeaslesOutbreakParams). This is not type-safe for overrides and will typically fail mypy (parameter types are contravariant in method overrides). Consider making BaseSimulationModel generic over a ParamsT = TypeVar(bound=BaseModel) and typing run(self, params: ParamsT, ...) / parameter_model(self) -> type[ParamsT], or keep the base signature as BaseModel and cast inside implementations.
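The generic approach suggested here could look roughly like this; the class names follow this PR, but the exact base-class layout is a sketch:

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from pydantic import BaseModel

ParamsT = TypeVar("ParamsT", bound=BaseModel)


class BaseSimulationModel(ABC, Generic[ParamsT]):
    @abstractmethod
    def parameter_model(self) -> type[ParamsT]:
        """Return the Pydantic schema used to validate uploaded parameter files."""

    @abstractmethod
    def run(self, params: ParamsT) -> dict:
        """Run the simulation with already-validated, typed parameters."""


class TBIsolationParams(BaseModel):
    isolation_days: int = 10


class TBIsolationModel(BaseSimulationModel[TBIsolationParams]):
    def parameter_model(self) -> type[TBIsolationParams]:
        return TBIsolationParams

    def run(self, params: TBIsolationParams) -> dict:
        # The narrowed parameter type is now consistent with the base class,
        # since the base is generic over ParamsT rather than fixed to BaseModel.
        return {"isolation_days": params.isolation_days}
```

With this shape, each concrete model keeps its precise parameter type without violating override variance rules.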
```python
    hourly_wage_public_health_worker: float = Field(
        alias="Hourly wage for public health worker", ge=0.0
    )
```
hourly_wage_public_health_worker is defined in TBCostParams (and appears in tb_isolation.yaml), but it is never referenced in TBIsolationModel.run(). This makes the parameter misleading for users and suggests a missing cost component. Either remove it from the schema/defaults or incorporate it into the model calculations.
Suggested change:

```diff
-    hourly_wage_public_health_worker: float = Field(
-        alias="Hourly wage for public health worker", ge=0.0
-    )
```
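For context, the `Field(alias=...)` pattern used in these schemas maps human-readable file keys onto Python attribute names. A minimal sketch, assuming Pydantic v2 and a hypothetical field name:

```python
from pydantic import BaseModel, ConfigDict, Field


class TBCostParams(BaseModel):
    # populate_by_name lets code construct the model with the Python attribute
    # name, while uploaded files can still use the human-readable alias.
    model_config = ConfigDict(populate_by_name=True)

    hourly_wage_nurse: float = Field(alias="Hourly wage for nurse", ge=0.0)


# Uploaded file keys use the alias...
from_file = TBCostParams.model_validate({"Hourly wage for nurse": 38.5})
# ...while in-code construction can use the attribute name.
in_code = TBCostParams(hourly_wage_nurse=38.5)
```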
```python
upload_bytes = uploaded_excel_model.getvalue()
upload_hash = hashlib.sha1(upload_bytes).hexdigest()
excel_identity = (uploaded_excel_model.name, len(upload_bytes), upload_hash)
```
uploaded_excel_model.getvalue() reads the entire XLSX into memory on every Streamlit rerun (even when the upload hasn’t changed) just to compute excel_identity. For large spreadsheets this can noticeably slow the UI. Consider using cheaper identity signals first (e.g., filename + size/last_modified if available) and only hashing when those change, or caching the computed hash/bytes in st.session_state so it’s computed once per upload.
Suggested change:

```diff
-upload_bytes = uploaded_excel_model.getvalue()
-upload_hash = hashlib.sha1(upload_bytes).hexdigest()
-excel_identity = (uploaded_excel_model.name, len(upload_bytes), upload_hash)
+upload_marker = (
+    uploaded_excel_model.name,
+    getattr(uploaded_excel_model, "size", None),
+    getattr(uploaded_excel_model, "file_id", None),
+)
+cached_upload_marker = st.session_state.get("excel_upload_marker")
+cached_upload_hash = st.session_state.get("excel_upload_hash")
+if cached_upload_marker != upload_marker or cached_upload_hash is None:
+    upload_bytes = uploaded_excel_model.getvalue()
+    cached_upload_hash = hashlib.sha1(upload_bytes).hexdigest()
+    st.session_state.excel_upload_marker = upload_marker
+    st.session_state.excel_upload_hash = cached_upload_hash
+excel_identity = (*upload_marker, cached_upload_hash)
```
```python
    upload_bytes = uploaded_params.getvalue()
    upload_hash = hashlib.sha1(upload_bytes).hexdigest()
    param_identity = (
        "upload",
        uploaded_params.name,
        len(upload_bytes),
        upload_hash,
    )
else:
```
uploaded_params.getvalue() reads the full parameter file to compute a SHA1 on every rerun. Even though these files are likely small, this pattern scales poorly and is inconsistent with Streamlit’s rerun model. Consider caching the computed digest in st.session_state (or hashing only when a cheap identity like filename+size changes) to avoid repeated full reads.
Suggested change:

```diff
-    upload_bytes = uploaded_params.getvalue()
-    upload_hash = hashlib.sha1(upload_bytes).hexdigest()
-    param_identity = (
-        "upload",
-        uploaded_params.name,
-        len(upload_bytes),
-        upload_hash,
-    )
-else:
+    upload_size = getattr(uploaded_params, "size", None)
+    upload_cache_identity = ("upload", uploaded_params.name, upload_size)
+    cached_upload_identity = st.session_state.get("uploaded_param_hash_identity")
+    cached_upload_hash = st.session_state.get("uploaded_param_hash")
+    if cached_upload_identity != upload_cache_identity or cached_upload_hash is None:
+        upload_bytes = uploaded_params.getvalue()
+        upload_size = len(upload_bytes)
+        upload_cache_identity = ("upload", uploaded_params.name, upload_size)
+        cached_upload_hash = hashlib.sha1(upload_bytes).hexdigest()
+        st.session_state.uploaded_param_hash_identity = upload_cache_identity
+        st.session_state.uploaded_param_hash = cached_upload_hash
+    param_identity = (
+        "upload",
+        uploaded_params.name,
+        upload_size,
+        cached_upload_hash,
+    )
+else:
+    st.session_state.pop("uploaded_param_hash_identity", None)
+    st.session_state.pop("uploaded_param_hash", None)
```
I'll edit this once I get #24 merged into here.