Onnx scripts with verification of model structure #305
Conversation
Greptile Summary

Added three utility scripts for exporting and verifying ONNX models, plus 18 MB of model files committed directly to the repository. The scripts export checkpoint weights to a flat binary, convert checkpoints to ONNX, and validate the exported models.
Critical Issue:
Major Concerns:
Code Quality Issues:
Confidence Score: 2/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[PyTorch Checkpoint .pt] --> B[export_model.py]
    A --> C[export_onnx.py]
    B --> D[Binary Weights .bin]
    D --> E[pufferlib/resources/drive/]
    C --> F[ONNX Model .onnx]
    C --> G[ONNX Data .onnx.data]
    C --> H[Test I/O .pt files]
    F --> I[verify_onnx.py]
    G --> I
    I --> J{Validation}
    J -->|Structure Check| K[onnx.checker]
    J -->|Runtime Test| L[onnxruntime]
    C -->|if verify=True| M[Compare Outputs]
    M -->|PyTorch vs ONNX| N[Numerical Verification]
```
Last reviewed commit: 6f1e654
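For context on the Validation node above: the verification leg of this pipeline reduces to a structural check plus a runtime smoke test. A minimal sketch (not the PR's code), assuming a float-only input signature and an illustrative model path:

```python
import numpy as np
import onnx
import onnxruntime as ort

# Illustrative path; the real artifact names come from export_onnx.py.
model_path = "pufferlib/resources/drive/policy.onnx"

# Structural check: validates the graph, opsets, and tensor metadata.
onnx.checker.check_model(onnx.load(model_path))

# Runtime smoke test: build dummy inputs from the session's input metadata,
# assuming every input is a float tensor (see the dtype discussion below).
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
feeds = {}
for meta in session.get_inputs():
    # Replace dynamic dimensions (strings / None / negative) with batch size 1.
    shape = [d if isinstance(d, int) and d > 0 else 1 for d in meta.shape]
    feeds[meta.name] = np.random.randn(*shape).astype(np.float32)

outputs = session.run(None, feeds)
print([o.shape for o in outputs])
```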
```python
def puffer_type(value):
    try:
        return ast.literal_eval(value)
    except:
        return value
```
Bare `except` catches all exceptions, including `KeyboardInterrupt` and `SystemExit`.
Suggested change:

```diff
 def puffer_type(value):
     try:
         return ast.literal_eval(value)
-    except:
+    except (ValueError, SyntaxError):
         return value
```
scripts/export_model.py
Outdated
```python
parser.add_argument(
    "--checkpoint",
    type=str,
    default="/scratch/pm3881/PufferDrive/experiments/puffer_drive_ntski86s/models/model_puffer_drive_000100.pt",
```
The hardcoded absolute path points at one user's scratch directory and will fail for anyone else running the script.
Pull request overview
Adds helper scripts to export PufferDrive PyTorch checkpoints to ONNX / flat weight binaries and to validate exported ONNX models via structural checks and runtime inference.
Changes:
- Added an ONNX export script that can optionally run an ORT vs PyTorch output comparison and save `.pt` I/O artifacts.
- Added an ONNX verification script that runs `onnx.checker` and a dummy-input ORT inference pass.
- Added a weight export script to dump model parameters into a single `.bin` file.
Reviewed changes
Copilot reviewed 3 out of 6 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| scripts/verify_onnx.py | New CLI to validate ONNX structure and do a basic ORT execution smoke test with generated inputs. |
| scripts/export_onnx.py | New CLI to export a Drive(+optional LSTM wrapper) policy to ONNX and optionally compare ORT outputs vs PyTorch. |
| scripts/export_model.py | New CLI to flatten and export model weights from a checkpoint into a .bin file. |
```python
dtype = input_meta.type

# Handle dynamic axes (often represented as strings or -1)
processed_shape = []
for dim in shape:
    if isinstance(dim, str) or dim is None or dim < 0:
        processed_shape.append(batch_size)
    else:
        processed_shape.append(dim)

print(f" Input: {name}, Shape: {shape} -> Using: {processed_shape}, Type: {dtype}")

# Create random input data
if "float" in dtype:
    data = np.random.randn(*processed_shape).astype(np.float32)
else:
    data = np.zeros(processed_shape).astype(np.int64)  # Fallback
```
Input dtype handling is too coarse: input_meta.type is typically a string like tensor(float16), tensor(double), tensor(int32), tensor(bool), etc. Always using float32 for any float type (and int64 for everything else) can make ORT inference fail purely due to dtype mismatch, even when the model is valid. Map ONNX Runtime input types to the correct NumPy dtype (e.g., float16/float32/float64/int32/int64/bool) and generate matching dummy arrays.
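One way to do that mapping; a sketch with an illustrative helper name, assuming ORT reports input types as strings of the form `tensor(<dtype>)`:

```python
import numpy as np

# Illustrative mapping from ONNX Runtime type strings to NumPy dtypes (not part of the PR).
ORT_TO_NUMPY = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
    "tensor(int32)": np.int32,
    "tensor(int64)": np.int64,
    "tensor(bool)": np.bool_,
}

def make_dummy_input(dtype_str, shape):
    # Fall back to float32 only if the type string is unrecognized.
    np_dtype = ORT_TO_NUMPY.get(dtype_str, np.float32)
    if np_dtype == np.bool_:
        return np.random.rand(*shape) > 0.5
    if np.issubdtype(np_dtype, np.floating):
        return np.random.randn(*shape).astype(np_dtype)
    return np.zeros(shape, dtype=np_dtype)
```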
| "--checkpoint", | ||
| type=str, | ||
| default="/scratch/pm3881/PufferDrive/experiments/puffer_drive_ntski86s/models/model_puffer_drive_000100.pt", | ||
| help="Path to .pt checkpoint", | ||
| ) | ||
| parser.add_argument( | ||
| "--output", | ||
| type=str, | ||
| default="pufferlib/resources/drive/model_puffer_drive_000100.bin", | ||
| help="Output .bin file path", | ||
| ) |
The default --checkpoint path is machine/user-specific (/scratch/...) and will break for most users running the script. Prefer no default (require the flag) or use a repo-relative default similar to scripts/export_onnx.py so the script is portable.
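A portable alternative, sketched assuming the script sits in scripts/ at the repository root; the default values below are placeholders, not paths from the PR:

```python
import argparse
import os

# Resolve paths relative to the repository root rather than a user-specific scratch dir.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

parser = argparse.ArgumentParser()
parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,  # or: default=os.path.join(REPO_ROOT, "experiments", "model.pt")
    help="Path to .pt checkpoint",
)
parser.add_argument(
    "--output",
    type=str,
    default=os.path.join(REPO_ROOT, "pufferlib/resources/drive/model.bin"),
    help="Output .bin file path",
)
```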
scripts/export_model.py
Outdated
```python
# Use valid dummy env to initialize policy
# Ensure env args/kwargs are correctly passed as expected by make()
env_args = []
```
env_args is defined but never used in this script. Removing it will reduce confusion about whether positional env args are expected/required here.
Suggested change:

```diff
-env_args = []
```
scripts/export_onnx.py
Outdated
```python
def load_config(env_name, config_dir=None):
    # Minimal config loader based on pufferl.py
    import configparser
    import glob
    from collections import defaultdict
    import ast

    if config_dir is None:
        puffer_dir = os.path.dirname(os.path.realpath(pufferlib.__file__))
    else:
        puffer_dir = config_dir

    puffer_config_dir = os.path.join(puffer_dir, "config/**/*.ini")
    puffer_default_config = os.path.join(puffer_dir, "config/default.ini")

    found = False
    for path in glob.glob(puffer_config_dir, recursive=True):
        p = configparser.ConfigParser()
        p.read([puffer_default_config, path])
        if env_name in p["base"]["env_name"].split():
            found = True
            break

    if not found:
        raise ValueError(f"No config for env_name {env_name}")

    def puffer_type(value):
        try:
            return ast.literal_eval(value)
        except:
            return value

    args = defaultdict(dict)
    for section in p.sections():
        for key in p[section]:
            value = puffer_type(p[section][key])
            args[section][key] = value

    return args
```
load_config is duplicated (also present in scripts/export_model.py). Duplicated config parsing logic tends to drift over time; consider moving it into a shared helper module under scripts/ (or an existing internal utility location) and importing it from both scripts.
Suggested change (drop the duplicated definition and import it instead):

```diff
-def load_config(env_name, config_dir=None):
-    # Minimal config loader based on pufferl.py
-    ...  # entire duplicated body, identical to the snippet above
-    return args
+from scripts.export_model import load_config
```
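An alternative to importing from scripts/export_model.py is to give the loader a neutral home, as the Copilot comment suggests; a layout sketch assuming a new, hypothetical scripts/config_utils.py module:

```python
# scripts/config_utils.py (hypothetical shared module)
# The single copy of load_config would live here, unchanged.

# scripts/export_model.py and scripts/export_onnx.py would then both import it:
from scripts.config_utils import load_config

config = load_config("puffer_drive")  # env name assumed for illustration
```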
I agree with the Copilot review here.
Nice! I think the folder structure could be a bit more organized: these scripts should probably live in their own folder within scripts/, with the test files loaded into that folder as well.
ONNX converter
ONNX I/O and structure verifier
Creates .pt files for I/O verification