Skip to content

Commit 4c00b25

Browse files
committed
add autocheck synthetic path accessibility
1 parent 45fdf2a commit 4c00b25

4 files changed

Lines changed: 493 additions & 0 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,4 @@ submit*.sh
5252
qual.sh
5353

5454
collectors/
55+
data/compounds

README_accessibility.md

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# 可购性检查 CLI 使用说明
2+
3+
本说明介绍如何使用本仓库内的可购性检查工具,基于两部分:
4+
- `scripts/check_accessible.py`:单个 SMILES 可购性检查(本地库存优先,其次 PubChem)。
5+
- `inference.py`:命令行工具(CLI)。可对模型预测得到的合成路径(paths)进行筛选:若存在“全可购”路径优先保留;若不存在,则保留“不可购节点最少”的路径集合。也可先调用模型生成 paths 再筛选。
6+
7+
适用场景:你已有模型输出的 `path_string` 列表,想快速过滤出所有节点都可购的路线;或希望一步生成+筛选。
8+
9+
目录定位(默认):
10+
- 可购库存清单:`data/compounds/buyables-stock.txt`
11+
- 可视化/模型配置:`data/configs/dms_dictionary.yaml`
12+
- 模型权重目录:`data/checkpoints/`
13+
14+
注意:仅“检查模式”(只用 `--paths-file`)不需要安装 torch;只有选择“先生成再检查”时才会加载模型相关依赖。输入 paths 支持两种形式:
15+
- JSON 字符串 path_string(如 "{'smiles':'...','children':[...]}")
16+
- 已解析的字典对象路径(形如 `{smiles, children}` 的树)
17+
18+
一、快速开始(仅检查已有 paths)
19+
- 文件格式支持:
20+
- JSON 数组:元素可以是 `path_string` 字符串,或已解析好的路径字典对象
21+
- 行分隔文本:每行一个 `path_string` 字符串
22+
23+
示例:
24+
- 只检查 children(起始物料),不要求根节点(产物)可购:
25+
- `python3 inference.py --paths-file my_paths.txt --children-only --output filtered.json`
26+
- 要求路径中包含根节点在内所有节点都可购:
27+
- `python3 inference.py --paths-file my_paths.txt --include-root --output filtered.json`
28+
29+
可选参数:
30+
- `--stock-file data/compounds/buyables-stock.txt` 覆盖默认库存清单
31+
- `--sleep 0.2` PubChem 查询限速(秒),默认 0.2
32+
33+
二、生成 + 检查(需要模型与 checkpoint)
34+
示例:
35+
- `python3 inference.py --target "CNCc1cc(-c2ccccc2F)n(S(=O)(=O)c2cccnc2)c1" --model explorer --beam-size 32 --children-only --output filtered.json`
36+
- 可选:`--n-steps`, `--starting-material`, `--ckpt-dir data/checkpoints`, `--config data/configs/dms_dictionary.yaml`
37+
38+
提示:脚本会在运行时自动将 `./src` 加入 `sys.path`,可直接以源码方式调用模块;若使用自定义安装环境,确保 `directmultistep` 模块可被导入。
39+
40+
三、筛选策略与输出说明(JSON)
41+
- 筛选策略:
42+
- 若存在“整条路径所有被检查节点都可购”的路径,优先保留这些(all_access)。
43+
- 若不存在全可购路径,则保留“不可购节点数最少”的路径集合(min_inaccessible)。若并列,全部保留。
44+
- 输出字段:
45+
- `total`: 候选路径数量
46+
- `accessible`: 全可购路径数量(仅统计 all_access)
47+
- `include_root`: 是否要求根节点一并检查
48+
- `selection`: `all_access``min_inaccessible`
49+
- `min_not_accessible`: 若为 `min_inaccessible` 策略,给出最小不可购节点数;否则为 null
50+
- `paths`: 选中的路径列表(JSON 结构:每条路径为 `{smiles, children}` 的字典树,不是字符串)
51+
- `statuses`: 每个出现过的 SMILES 的判定结果,形如 `{"status": "purchasable" | "no_vendors" | "not_found" | "error", "detail": "..."}`
52+
- `path_reports`: 每条候选路径的诊断:`{"path": <JSON路径>, "not_accessible": <数量>, "not_accessible_smiles": [..]}`
53+
54+
四、判定逻辑与来源
55+
- 先查本地清单(精确匹配):`data/compounds/buyables-stock.txt`
56+
- 不在本地清单时,调用 PubChem PUG/PUG View:
57+
- SMILES → CID;若 CID == 0 或未找到,判为 `not_found`
58+
- CID → Chemical Vendors;若存在并为真,判为 `purchasable`;否则 `no_vendors`
59+
- 为减少对 PubChem 的压力,在网络查询间加入 `--sleep` 延迟。
60+
61+
五、返回码(exit code)
62+
- 0:至少有一条路径被选中(全可购或“最少不可购”)
63+
- 3:没有路径被选中
64+
- 2:输入错误或文件不存在等问题
65+
66+
六、常见用法小抄
67+
- 检查 children-only(最常用):
68+
- `python3 inference.py --paths-file my_paths.txt --children-only --output filtered.json`
69+
- 使用自定义库存清单:
70+
- `python3 inference.py --paths-file my_paths.txt --children-only --stock-file path/to/stock.txt`
71+
- 生成+检查(含根节点):
72+
- `python3 inference.py --target "<SMILES>" --model explorer --include-root --output filtered.json`
73+
- 若不存在全可购路径,输出会自动切换到 `min_inaccessible` 策略并保留不可购最少的路径集合。
74+
75+
七、局限与建议
76+
- PubChem 作为可购性代理指标对常见试剂较准,但对盐型/互变异构体/异常表示可能返回 `not_found``no_vendors`
77+
- 本地清单为精确字符串匹配;若需 SMILES 规范化后匹配,可扩展脚本进行标准化。
78+
- 批量大规模查询时适当增大 `--sleep`,避免过快触发限速。
79+
80+
相关脚本
81+
- `inference.py`:CLI 主脚本
82+
- `scripts/check_accessible.py`:单个 SMILES 判定逻辑(可独立运行)
83+
84+
问题反馈或扩展(建议)
85+
- 需要导出“不可购/未知”路径及不可购节点详细清单(CSV/JSON)?
86+
- 需要只检查叶子节点(起始原料)/只检查 children/包含根节点的不同策略?目前已支持 `--children-only``--include-root`

inference.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Allow running without installation by adding ./src to sys.path (done lazily in main before model import)
2+
3+
from pathlib import Path
4+
from typing import Any, Dict, List, Sequence, Set, Tuple
5+
6+
import argparse
7+
import json
8+
import ast
9+
import requests
10+
import sys
11+
import time
12+
13+
# Reuse the per-SMILES checker from scripts/check_accessible.py
14+
from scripts.check_accessible import check_accessible # type: ignore
15+
16+
17+
def _iter_all_smiles(node: Dict[str, Any]) -> List[str]:
18+
smiles: List[str] = []
19+
smi = node.get("smiles")
20+
if isinstance(smi, str) and smi:
21+
smiles.append(smi)
22+
for child in node.get("children", []) or []:
23+
smiles.extend(_iter_all_smiles(child))
24+
return smiles
25+
26+
27+
def _load_stock(stock_file: Path) -> Set[str]:
28+
stock: Set[str] = set()
29+
if stock_file.is_file():
30+
with stock_file.open("r", encoding="utf-8") as f:
31+
for line in f:
32+
t = line.strip()
33+
if not t or t.startswith("#"):
34+
continue
35+
stock.add(t)
36+
return stock
37+
38+
39+
def filter_accessible_paths(
40+
paths: Sequence[Any], stock_file: Path, include_root: bool = True, sleep: float = 0.2
41+
) -> Tuple[List[Dict[str, Any]], Dict[str, Dict[str, str]], List[Dict[str, Any]]]:
42+
"""Return only those path strings where all nodes are accessible.
43+
44+
Uses local stock first, then PubChem via check_accessible(). Returns the filtered
45+
list and a map of SMILES -> {status, detail} used for the decision.
46+
"""
47+
stock = _load_stock(stock_file)
48+
session = requests.Session()
49+
status_cache: Dict[str, Dict[str, str]] = {}
50+
51+
def is_accessible_smiles(smi: str) -> bool:
52+
if smi in status_cache:
53+
return status_cache[smi]["status"] == "purchasable"
54+
if smi in stock:
55+
status_cache[smi] = {"status": "purchasable", "detail": "found in stock file"}
56+
return True
57+
status, detail = check_accessible(smi, session)
58+
# Be polite to PubChem
59+
time.sleep(sleep)
60+
status_cache[smi] = {"status": status, "detail": detail}
61+
return status == "purchasable"
62+
63+
filtered: List[Dict[str, Any]] = []
64+
reports: List[Dict[str, Any]] = []
65+
for path_item in paths:
66+
# Accept either a stringified Python/JSON literal or a dict-like object
67+
node: Dict[str, Any]
68+
if isinstance(path_item, str):
69+
try:
70+
node = ast.literal_eval(path_item)
71+
except Exception:
72+
# skip malformed path strings
73+
continue
74+
path_string = path_item
75+
elif isinstance(path_item, dict):
76+
node = path_item
77+
path_string = None # not used for output, keep node structure
78+
else:
79+
# Attempt to coerce unknown objects exposing mapping-like API
80+
try:
81+
node = dict(path_item) # type: ignore[arg-type]
82+
path_string = None
83+
except Exception:
84+
continue
85+
smiles_chain = _iter_all_smiles(node)
86+
# optionally exclude root
87+
if not include_root and smiles_chain:
88+
smiles_chain = smiles_chain[1:]
89+
ok = True
90+
bad_nodes: List[str] = []
91+
for smi in smiles_chain:
92+
if not is_accessible_smiles(smi):
93+
ok = False
94+
bad_nodes.append(smi)
95+
if ok:
96+
filtered.append(node)
97+
# Collect a per-path report
98+
reports.append({
99+
"path": node,
100+
"not_accessible": len(bad_nodes),
101+
"not_accessible_smiles": bad_nodes,
102+
})
103+
return filtered, status_cache, reports
104+
105+
106+
def _load_paths_from_file(p: Path) -> List[Any]:
107+
text = p.read_text(encoding="utf-8").strip()
108+
if not text:
109+
return []
110+
# Try JSON array first
111+
try:
112+
arr = json.loads(text)
113+
if isinstance(arr, list):
114+
# Accept lists of strings or dicts
115+
if all(isinstance(x, (str, dict)) for x in arr):
116+
return arr
117+
except Exception:
118+
pass
119+
# Fallback: newline-delimited strings
120+
return [line for line in text.splitlines() if line.strip()]
121+
122+
123+
def main() -> int:
124+
data_path = Path("./data")
125+
default_ckpt = data_path / "checkpoints"
126+
default_config = data_path / "configs" / "dms_dictionary.yaml"
127+
default_stock = data_path / "compounds" / "buyables-stock.txt"
128+
default_output = data_path / "accessible_paths_from_inference.json"
129+
130+
ap = argparse.ArgumentParser(description="Generate routes (optional) and filter paths where all nodes are purchasable.")
131+
src = ap.add_argument_group("source of paths")
132+
src.add_argument("--paths-file", type=Path, default=None, help="If provided, read candidate paths from file (JSON array or newline separated)")
133+
src.add_argument("--target", default=None, help="Target SMILES. Required if not using --paths-file")
134+
src.add_argument("--n-steps", type=int, default=None, help="Number of steps (None lets the model decide)")
135+
src.add_argument("--starting-material", default=None, help="Optional starting material SMILES")
136+
src.add_argument("--model", default="explorer", help="Model name or checkpoint; defaults to explorer")
137+
src.add_argument("--beam-size", type=int, default=32, help="Beam size (default: 32)")
138+
src.add_argument("--ckpt-dir", type=Path, default=default_ckpt, help="Checkpoint directory (default: data/checkpoints)")
139+
src.add_argument("--config", type=Path, default=default_config, help="Config path (default: data/configs/dms_dictionary.yaml)")
140+
141+
filt = ap.add_argument_group("filter options")
142+
filt.add_argument("--stock-file", type=Path, default=default_stock, help="Local purchasable SMILES list")
143+
excl = filt.add_mutually_exclusive_group()
144+
excl.add_argument("--include-root", action="store_true", help="Require root (product) to be purchasable")
145+
excl.add_argument("--children-only", action="store_true", help="Only require children (reactants) to be purchasable")
146+
filt.add_argument("--sleep", type=float, default=0.2, help="Delay between PubChem requests (default: 0.2s)")
147+
148+
ap.add_argument("--output", type=Path, default=default_output, help="Output JSON file path")
149+
150+
args = ap.parse_args()
151+
152+
# Determine include_root flag
153+
include_root = True
154+
if args.children_only:
155+
include_root = False
156+
elif args.include_root:
157+
include_root = True
158+
159+
# Load or generate candidate paths
160+
if args.paths_file is not None:
161+
if not args.paths_file.is_file():
162+
print(f"Error: paths file not found: {args.paths_file}", file=sys.stderr)
163+
return 2
164+
paths = _load_paths_from_file(args.paths_file)
165+
else:
166+
if not args.target:
167+
print("Error: --target is required when --paths-file is not provided", file=sys.stderr)
168+
return 2
169+
# Lazy import to avoid requiring torch when only checking paths
170+
# Also add ./src to sys.path if available
171+
repo_root = Path(__file__).resolve().parent
172+
src_dir = repo_root / "src"
173+
if (src_dir / "directmultistep").exists():
174+
sys.path.insert(0, str(src_dir))
175+
from directmultistep.generate import generate_routes # type: ignore
176+
paths = generate_routes(
177+
target=args.target,
178+
n_steps=args.n_steps,
179+
starting_material=args.starting_material,
180+
model=args.model,
181+
beam_size=args.beam_size,
182+
config_path=args.config,
183+
ckpt_dir=args.ckpt_dir,
184+
)
185+
186+
# Filter by accessibility
187+
filtered_paths, statuses, reports = filter_accessible_paths(
188+
paths=paths,
189+
stock_file=args.stock_file,
190+
include_root=include_root,
191+
sleep=args.sleep,
192+
)
193+
194+
selection_mode = "all_access"
195+
min_inacc = None
196+
selected_paths: List[Dict[str, Any]] = filtered_paths
197+
# If none fully accessible, pick those with minimal number of not-accessible nodes
198+
if not selected_paths:
199+
if reports:
200+
min_inacc = min(r["not_accessible"] for r in reports)
201+
selection_mode = "min_inaccessible"
202+
selected_paths = [r["path"] for r in reports if r["not_accessible"] == min_inacc]
203+
204+
args.output.parent.mkdir(parents=True, exist_ok=True)
205+
with args.output.open("w", encoding="utf-8") as f:
206+
json.dump(
207+
{
208+
"total": len(paths),
209+
"accessible": len(filtered_paths),
210+
"include_root": include_root,
211+
"selection": selection_mode,
212+
"min_not_accessible": min_inacc,
213+
"paths": selected_paths,
214+
"statuses": statuses,
215+
"path_reports": reports,
216+
},
217+
f,
218+
indent=2,
219+
ensure_ascii=True,
220+
)
221+
print(
222+
f"Candidates: {len(paths)}; fully-accessible ({'including root' if include_root else 'children only'}): {len(filtered_paths)}; "
223+
f"selected: {len(selected_paths)} (mode={selection_mode}, min_not_accessible={min_inacc}); saved: {args.output}"
224+
)
225+
return 0 if selected_paths else 3
226+
227+
228+
if __name__ == "__main__":
229+
sys.exit(main())

0 commit comments

Comments
 (0)