|
| 1 | +# Allow running without installation by adding ./src to sys.path (done lazily in main before model import) |
| 2 | + |
| 3 | +from pathlib import Path |
| 4 | +from typing import Any, Dict, List, Sequence, Set, Tuple |
| 5 | + |
| 6 | +import argparse |
| 7 | +import json |
| 8 | +import ast |
| 9 | +import requests |
| 10 | +import sys |
| 11 | +import time |
| 12 | + |
| 13 | +# Reuse the per-SMILES checker from scripts/check_accessible.py |
| 14 | +from scripts.check_accessible import check_accessible # type: ignore |
| 15 | + |
| 16 | + |
| 17 | +def _iter_all_smiles(node: Dict[str, Any]) -> List[str]: |
| 18 | + smiles: List[str] = [] |
| 19 | + smi = node.get("smiles") |
| 20 | + if isinstance(smi, str) and smi: |
| 21 | + smiles.append(smi) |
| 22 | + for child in node.get("children", []) or []: |
| 23 | + smiles.extend(_iter_all_smiles(child)) |
| 24 | + return smiles |
| 25 | + |
| 26 | + |
| 27 | +def _load_stock(stock_file: Path) -> Set[str]: |
| 28 | + stock: Set[str] = set() |
| 29 | + if stock_file.is_file(): |
| 30 | + with stock_file.open("r", encoding="utf-8") as f: |
| 31 | + for line in f: |
| 32 | + t = line.strip() |
| 33 | + if not t or t.startswith("#"): |
| 34 | + continue |
| 35 | + stock.add(t) |
| 36 | + return stock |
| 37 | + |
| 38 | + |
def filter_accessible_paths(
    paths: Sequence[Any], stock_file: Path, include_root: bool = True, sleep: float = 0.2
) -> Tuple[List[Dict[str, Any]], Dict[str, Dict[str, str]], List[Dict[str, Any]]]:
    """Keep only the candidate paths whose checked nodes are all purchasable.

    Each SMILES is looked up in the local stock file first and, if absent,
    passed to ``check_accessible`` (the PubChem-backed checker this file
    imports). Results are cached per SMILES, so each distinct string is
    checked at most once per call.

    Args:
        paths: Candidate paths. Each item may be a dict tree (nodes with
            ``"smiles"`` and optional ``"children"``), a string holding a
            Python or JSON literal of such a tree, or any object accepted by
            ``dict()``. Items that cannot be turned into a dict are skipped
            and produce no report entry.
        stock_file: Text file of purchasable SMILES, one per line.
        include_root: When False, the top-level node's own SMILES is not
            checked; only the SMILES found beneath its children are.
        sleep: Seconds to pause after every ``check_accessible`` call.

    Returns:
        A 3-tuple ``(filtered, status_cache, reports)``:

        * ``filtered`` - parsed path dicts whose checked SMILES were all
          reported as ``"purchasable"``.
        * ``status_cache`` - ``SMILES -> {"status": ..., "detail": ...}`` for
          every SMILES that was looked at.
        * ``reports`` - one dict per parsed path with keys ``"path"``,
          ``"not_accessible"`` (count) and ``"not_accessible_smiles"``.
    """
    stock = _load_stock(stock_file)
    status_cache: Dict[str, Dict[str, str]] = {}
    filtered: List[Dict[str, Any]] = []
    reports: List[Dict[str, Any]] = []

    # The session is only needed for this call; the context manager closes
    # it (and its pooled connections) even if a check raises.
    with requests.Session() as session:

        def is_accessible_smiles(smi: str) -> bool:
            """Return True when *smi* is purchasable, consulting the cache first."""
            cached = status_cache.get(smi)
            if cached is not None:
                return cached["status"] == "purchasable"
            if smi in stock:
                status_cache[smi] = {"status": "purchasable", "detail": "found in stock file"}
                return True
            status, detail = check_accessible(smi, session)
            # Be polite to PubChem: pause after every remote lookup.
            time.sleep(sleep)
            status_cache[smi] = {"status": status, "detail": detail}
            return status == "purchasable"

        for path_item in paths:
            node: Any
            if isinstance(path_item, str):
                try:
                    node = ast.literal_eval(path_item)
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    # Not a Python literal; JSON spellings such as
                    # null/true/false are rejected by literal_eval.
                    try:
                        node = json.loads(path_item)
                    except (ValueError, RecursionError):
                        # Skip malformed path strings.
                        continue
            elif isinstance(path_item, dict):
                node = path_item
            else:
                # Best-effort coercion of unknown mapping-like objects; any
                # failure simply means the item is not usable as a path.
                try:
                    node = dict(path_item)  # type: ignore[arg-type]
                except Exception:
                    continue

            # A literal can evaluate to a list, tuple, string, number, ...;
            # only dict trees can be walked.
            if not isinstance(node, dict):
                continue

            if include_root:
                smiles_chain = _iter_all_smiles(node)
            else:
                # Walk the children directly instead of slicing off the first
                # collected SMILES: when the root has no usable "smiles"
                # value, slicing would drop a child's SMILES instead.
                smiles_chain = [
                    smi
                    for child in node.get("children", []) or []
                    for smi in _iter_all_smiles(child)
                ]

            # Check every SMILES (no short-circuit) so the report lists all
            # inaccessible nodes and the cache covers the whole path.
            bad_nodes: List[str] = [smi for smi in smiles_chain if not is_accessible_smiles(smi)]
            if not bad_nodes:
                filtered.append(node)
            reports.append({
                "path": node,
                "not_accessible": len(bad_nodes),
                "not_accessible_smiles": bad_nodes,
            })

    return filtered, status_cache, reports
| 104 | + |
| 105 | + |
| 106 | +def _load_paths_from_file(p: Path) -> List[Any]: |
| 107 | + text = p.read_text(encoding="utf-8").strip() |
| 108 | + if not text: |
| 109 | + return [] |
| 110 | + # Try JSON array first |
| 111 | + try: |
| 112 | + arr = json.loads(text) |
| 113 | + if isinstance(arr, list): |
| 114 | + # Accept lists of strings or dicts |
| 115 | + if all(isinstance(x, (str, dict)) for x in arr): |
| 116 | + return arr |
| 117 | + except Exception: |
| 118 | + pass |
| 119 | + # Fallback: newline-delimited strings |
| 120 | + return [line for line in text.splitlines() if line.strip()] |
| 121 | + |
| 122 | + |
def main() -> int:
    """Command-line entry point: obtain candidate paths, filter them, write a JSON report.

    Candidate paths are read from ``--paths-file`` when it is given;
    otherwise they are produced by ``directmultistep.generate.generate_routes``
    for ``--target``. The candidates are passed to ``filter_accessible_paths``
    and a summary is written to ``--output``. When no candidate is fully
    accessible, the paths with the fewest inaccessible nodes are selected
    instead.

    Returns:
        0 when at least one path was selected, 2 when the paths file is
        missing or ``--target`` was required but not given, 3 when nothing
        could be selected.
    """
    data_root = Path("./data")

    parser = argparse.ArgumentParser(
        description="Generate routes (optional) and filter paths where all nodes are purchasable."
    )

    source_group = parser.add_argument_group("source of paths")
    source_group.add_argument(
        "--paths-file",
        type=Path,
        default=None,
        help="If provided, read candidate paths from file (JSON array or newline separated)",
    )
    source_group.add_argument(
        "--target",
        default=None,
        help="Target SMILES. Required if not using --paths-file",
    )
    source_group.add_argument(
        "--n-steps",
        type=int,
        default=None,
        help="Number of steps (None lets the model decide)",
    )
    source_group.add_argument(
        "--starting-material",
        default=None,
        help="Optional starting material SMILES",
    )
    source_group.add_argument(
        "--model",
        default="explorer",
        help="Model name or checkpoint; defaults to explorer",
    )
    source_group.add_argument(
        "--beam-size",
        type=int,
        default=32,
        help="Beam size (default: 32)",
    )
    source_group.add_argument(
        "--ckpt-dir",
        type=Path,
        default=data_root / "checkpoints",
        help="Checkpoint directory (default: data/checkpoints)",
    )
    source_group.add_argument(
        "--config",
        type=Path,
        default=data_root / "configs" / "dms_dictionary.yaml",
        help="Config path (default: data/configs/dms_dictionary.yaml)",
    )

    filter_group = parser.add_argument_group("filter options")
    filter_group.add_argument(
        "--stock-file",
        type=Path,
        default=data_root / "compounds" / "buyables-stock.txt",
        help="Local purchasable SMILES list",
    )
    root_mode = filter_group.add_mutually_exclusive_group()
    root_mode.add_argument(
        "--include-root",
        action="store_true",
        help="Require root (product) to be purchasable",
    )
    root_mode.add_argument(
        "--children-only",
        action="store_true",
        help="Only require children (reactants) to be purchasable",
    )
    filter_group.add_argument(
        "--sleep",
        type=float,
        default=0.2,
        help="Delay between PubChem requests (default: 0.2s)",
    )

    parser.add_argument(
        "--output",
        type=Path,
        default=data_root / "accessible_paths_from_inference.json",
        help="Output JSON file path",
    )

    args = parser.parse_args()

    # The root is checked unless --children-only was requested; --include-root
    # merely spells out the default.
    include_root = not args.children_only

    # Obtain the candidate paths, either by generating them or from a file.
    if args.paths_file is None:
        if not args.target:
            print("Error: --target is required when --paths-file is not provided", file=sys.stderr)
            return 2
        # Import lazily so the model stack is only needed when generating
        # routes; make ./src importable first when it holds the package.
        script_dir = Path(__file__).resolve().parent
        local_src = script_dir / "src"
        if (local_src / "directmultistep").exists():
            sys.path.insert(0, str(local_src))
        from directmultistep.generate import generate_routes  # type: ignore

        candidates = generate_routes(
            target=args.target,
            n_steps=args.n_steps,
            starting_material=args.starting_material,
            model=args.model,
            beam_size=args.beam_size,
            config_path=args.config,
            ckpt_dir=args.ckpt_dir,
        )
    elif args.paths_file.is_file():
        candidates = _load_paths_from_file(args.paths_file)
    else:
        print(f"Error: paths file not found: {args.paths_file}", file=sys.stderr)
        return 2

    # Filter by accessibility.
    fully_accessible, statuses, reports = filter_accessible_paths(
        paths=candidates,
        stock_file=args.stock_file,
        include_root=include_root,
        sleep=args.sleep,
    )

    # Prefer fully accessible paths; otherwise fall back to the paths with
    # the smallest number of inaccessible nodes (when any were reported).
    chosen: List[Dict[str, Any]]
    if fully_accessible or not reports:
        mode = "all_access"
        fewest_missing = None
        chosen = fully_accessible
    else:
        fewest_missing = min(entry["not_accessible"] for entry in reports)
        mode = "min_inaccessible"
        chosen = [entry["path"] for entry in reports if entry["not_accessible"] == fewest_missing]

    summary = {
        "total": len(candidates),
        "accessible": len(fully_accessible),
        "include_root": include_root,
        "selection": mode,
        "min_not_accessible": fewest_missing,
        "paths": chosen,
        "statuses": statuses,
        "path_reports": reports,
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as out_file:
        json.dump(summary, out_file, indent=2, ensure_ascii=True)

    message = (
        f"Candidates: {len(candidates)}; fully-accessible ({'including root' if include_root else 'children only'}): {len(fully_accessible)}; "
        f"selected: {len(chosen)} (mode={mode}, min_not_accessible={fewest_missing}); saved: {args.output}"
    )
    print(message)
    return 0 if chosen else 3
| 226 | + |
| 227 | + |
# Script entry point: run main() and use its integer return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
0 commit comments