diff --git a/cli/build.rs b/cli/build.rs new file mode 100644 index 00000000..cd7754ba --- /dev/null +++ b/cli/build.rs @@ -0,0 +1,20 @@ +use std::process::Command; + +fn main() { + // Stamp the build with a short git commit so `mfs --version` can answer + // "which build is this" without ps-aux archaeology across install paths + // (cargo install / uv tool / a worktree's own target dir can all differ). + // Falls back to "unknown" outside a git checkout (e.g. a source tarball). + let sha = Command::new("git") + .args(["rev-parse", "--short", "HEAD"]) + .output() + .ok() + .filter(|o| o.status.success()) + .and_then(|o| String::from_utf8(o.stdout).ok()) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| "unknown".to_string()); + println!("cargo:rustc-env=MFS_GIT_SHA={sha}"); + println!("cargo:rerun-if-changed=../.git/HEAD"); + println!("cargo:rerun-if-changed=../.git/index"); +} diff --git a/cli/src/main.rs b/cli/src/main.rs index e092a162..cb6123f2 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -6,7 +6,9 @@ use clap::{Parser, Subcommand}; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::BTreeMap; +use std::net::TcpStream; use std::path::PathBuf; +use std::time::{Duration, Instant}; type CliResult = Result; @@ -117,7 +119,11 @@ impl std::fmt::Display for CliError { } #[derive(Parser)] -#[command(name = "mfs", version, about = "Multi-source File-like Search")] +#[command( + name = "mfs", + version = concat!(env!("CARGO_PKG_VERSION"), " (", env!("MFS_GIT_SHA"), ")"), + about = "Multi-source File-like Search" +)] struct Cli { #[command(subcommand)] cmd: Cmd, @@ -141,10 +147,10 @@ enum Cmd { #[arg(long, visible_alias = "full")] force_index: bool, /// Bundle + upload the tree to the server even on the same host (no shared fs) - #[arg(long)] + #[arg(long, conflicts_with = "no_upload")] upload: bool, /// Re-upload every file (skip the manifest diff) and force a full re-index - #[arg(long)] + #[arg(long, conflicts_with = "no_upload")] force_upload: bool, /// Never upload; have the server read the path itself (shared fs) #[arg(long)] @@ -188,7 +194,7 @@ enum Cmd { /// Line range, 1-based half-open: `start:end` returns lines start..end-1 /// (e.g. `--range 1:11` = first 10 lines). Matches `locator.lines` from /// search hits. - #[arg(long)] + #[arg(long, allow_hyphen_values = true)] range: Option, #[arg(long)] meta: bool, @@ -332,7 +338,10 @@ enum ServeAction { bind: String, }, /// Is the local mfs-server running? - Status, + Status { + #[arg(long, default_value = "127.0.0.1:13619")] + bind: String, + }, /// Tail the local server log Logs, } @@ -420,11 +429,38 @@ fn is_remote(base: &str) -> bool { !(base.contains("127.0.0.1") || base.contains("localhost") || base.contains("[::1]")) } +/// Extract the underlying filesystem path from a `file:///abs` or `file://local/abs` +/// target, mirroring the slicing the server's FilePlugin.derive_target does for the +/// same two forms. Returns None for a bare path or any other file:// identity (e.g. +/// the file:// upload identity), which callers fall back to handling as-is. +fn local_fs_path_from_target(target: &str) -> Option { + if target.starts_with("file:///") { + return Some(target["file://".len()..].to_string()); + } + if target.starts_with("file://local/") { + return Some(target["file://local".len()..].to_string()); + } + None +} + fn remote_path(base: &str, path: &str) -> CliResult { + // As with `add`'s target (see local_fs_path_from_target), a `file:///abs` or + // `file://local/abs` spelling of a local path is never itself a literal path on + // disk -- strip it before canonicalizing so path-scoping (search/grep/ls/cat/...) + // resolves it the same way the server's connector-matching does. + let fs_path = local_fs_path_from_target(path); + let canon_input = fs_path.as_deref().unwrap_or(path); if is_remote(base) { - if let Ok(abs) = std::fs::canonicalize(path) { + if let Ok(abs) = std::fs::canonicalize(canon_input) { return Ok(format!("file://{}{}", client_id()?, abs.to_string_lossy())); } + } else if fs_path.is_some() { + // Rewrite to the canonical file://local identity so a URI-spelled path + // matches the connector's registered root_uri, which the server compares by + // raw string prefix for "://"-bearing scopes (no scheme normalization there). + if let Ok(abs) = std::fs::canonicalize(canon_input) { + return Ok(format!("file://local{}", abs.to_string_lossy())); + } } Ok(path.to_string()) } @@ -468,7 +504,9 @@ fn resolve_path_arg( base: &str, path: &str, ) -> CliResult { - if let Ok(abs) = std::fs::canonicalize(path) { + let fs_path = local_fs_path_from_target(path); + let canon_input = fs_path.as_deref().unwrap_or(path); + if let Ok(abs) = std::fs::canonicalize(canon_input) { let abs_path = abs.to_string_lossy().to_string(); if let Ok(status) = get(client, &format!("{base}/v1/status"), &[]) { if let Some(mapped) = uploaded_local_path_from_status(&status, &client_id()?, &abs_path) @@ -561,6 +599,23 @@ fn profile_url_from_cfg(cfg: &ClientConfig) -> CliResult> { } } +// reqwest's request builder errors out once the built URL/request-line exceeds an +// internal length limit, and that error string echoes the whole value back -- +// dumping the entire (potentially huge) query into stderr/logs. Reject oversized +// values up front with a short, actionable message instead. +const MAX_QUERY_CHARS: usize = 8000; + +fn check_query_length(label: &str, code: &str, value: &str) -> CliResult<()> { + let len = value.chars().count(); + if len > MAX_QUERY_CHARS { + return Err(CliError::new( + code, + format!("{label} too long ({len} chars, max {MAX_QUERY_CHARS})"), + )); + } + Ok(()) +} + fn base_url() -> CliResult { if let Ok(u) = std::env::var("MFS_API_URL") { if !u.is_empty() { @@ -610,13 +665,18 @@ fn run(cli: &Cli, client: &reqwest::blocking::Client, base: &str) -> CliResult<( no_upload, yes, } => { - let is_local = std::path::Path::new(target).exists(); + // A file:///abs or file://local/abs target is exactly as local as the bare + // path it encodes -- resolve it the same way the server's FilePlugin does + // before checking existence, so spelling a local target as a URI doesn't + // wrongly trip the external-connector cost-estimate prompt below. + let fs_path = local_fs_path_from_target(target); + let is_local = std::path::Path::new(fs_path.as_deref().unwrap_or(target)).exists(); // Make a bare/relative local path absolute CLIENT-side before sending: a // loopback server resolves a relative path against its OWN cwd (not the user's), // so `mfs add ./repo` would 500 with a server-side FileNotFoundError. Canonicalizing // to the stable file://local identity also keeps search/cat/remove consistent. let canon_target: String = if is_local { - std::fs::canonicalize(target) + std::fs::canonicalize(fs_path.as_deref().unwrap_or(target)) .map(|p| p.to_string_lossy().into_owned()) .unwrap_or_else(|_| target.clone()) } else { @@ -703,6 +763,7 @@ fn run(cli: &Cli, client: &reqwest::blocking::Client, base: &str) -> CliResult<( kind, collapse, } => { + check_query_length("query", "query_too_long", query)?; if path.is_none() && !all { return Err( "specify a path to scope the search, or --all for the whole namespace".into(), @@ -747,6 +808,7 @@ fn run(cli: &Cli, client: &reqwest::blocking::Client, base: &str) -> CliResult<( } } Cmd::Grep { pattern, path } => { + check_query_length("pattern", "pattern_too_long", pattern)?; let v = get( client, &format!("{base}/v1/grep"), @@ -826,7 +888,7 @@ fn run(cli: &Cli, client: &reqwest::blocking::Client, base: &str) -> CliResult<( if *meta { println!("{v}"); } else { - println!("{}", v["content"].as_str().unwrap_or("")); + print!("{}", v["content"].as_str().unwrap_or("")); } } Cmd::Head { path, lines } => { @@ -841,7 +903,7 @@ fn run(cli: &Cli, client: &reqwest::blocking::Client, base: &str) -> CliResult<( if cli.json { println!("{v}"); } else { - println!("{}", v["content"].as_str().unwrap_or("")); + print!("{}", v["content"].as_str().unwrap_or("")); } } Cmd::Tail { path, lines } => { @@ -856,7 +918,7 @@ fn run(cli: &Cli, client: &reqwest::blocking::Client, base: &str) -> CliResult<( if cli.json { println!("{v}"); } else { - println!("{}", v["content"].as_str().unwrap_or("")); + print!("{}", v["content"].as_str().unwrap_or("")); } } Cmd::Export { path, out } => { @@ -891,12 +953,20 @@ fn run(cli: &Cli, client: &reqwest::blocking::Client, base: &str) -> CliResult<( return Ok(()); } for j in v.as_array().unwrap_or(&vec![]) { - println!( + let status = j["status"].as_str().unwrap_or("?"); + print!( "{:8} {:10} {}", - j["status"].as_str().unwrap_or("?"), + status, j["op_kind"].as_str().unwrap_or("?"), j["id"].as_str().unwrap_or("?") ); + if status == "failed" { + if let Some(err) = j["error"].as_str() { + let snippet: String = err.chars().take(80).collect(); + print!(" — {snippet}"); + } + } + println!(); } } JobAction::Show { job_id } => { @@ -1424,6 +1494,19 @@ where Ok(expanded) } +fn validate_profile_url(url: &str) -> CliResult<()> { + let parsed = reqwest::Url::parse(url).map_err(|_| { + format!("invalid profile URL '{url}': must be a valid http:// or https:// URL") + })?; + if parsed.scheme() != "http" && parsed.scheme() != "https" { + return Err(format!( + "invalid profile URL '{url}': must be a valid http:// or https:// URL" + ) + .into()); + } + Ok(()) +} + fn profile_cmd(action: &ProfileAction, json: bool) -> CliResult<()> { let mut cfg = load_client_cfg()?; match action { @@ -1431,6 +1514,7 @@ fn profile_cmd(action: &ProfileAction, json: bool) -> CliResult<()> { println!("{}", profile_list_output(&cfg, &base_url()?, json)); } ProfileAction::Add { name, url, token } => { + validate_profile_url(url)?; cfg.profiles.insert( name.clone(), Profile { @@ -1529,21 +1613,60 @@ fn serve_cmd(action: &ServeAction) -> Result<(), String> { return Ok(()); } } + // Pre-flight: a pidfile only tracks processes this CLI itself + // launched. If something not tracked by the pidfile (started by + // hand, e.g. `uv run mfs-server run`) already holds `bind`, don't + // spawn a duplicate — it would sit for however long this + // build's startup takes (Milvus connect + embedding model + // preload; ~15-20s observed) before failing to bind and dying, + // during which `serve status` would misreport a healthy "running" + // for a doomed process while the real server is never touched. + if tcp_probe(bind, Duration::from_millis(500)) { + println!( + "something is already listening on {bind}, but not a process this CLI launched (no matching pidfile). Not starting a second instance — if that server is stale, stop it manually first." + ); + return Ok(()); + } std::fs::create_dir_all(mfs_home()).map_err(|e| e.to_string())?; let log = std::fs::File::create(&log_file).map_err(|e| e.to_string())?; let log_err = log.try_clone().map_err(|e| e.to_string())?; - let child = std::process::Command::new("mfs-server") + let mut child = std::process::Command::new("mfs-server") .args(["run", "--bind", bind]) .stdout(std::process::Stdio::from(log)) .stderr(std::process::Stdio::from(log_err)) .spawn() .map_err(|e| format!("failed to spawn mfs-server: {e}"))?; - std::fs::write(&pid_file, child.id().to_string()).map_err(|e| e.to_string())?; - println!( - "started mfs-server (pid {}) on {bind}; logs: {}", - child.id(), - log_file.display() - ); + + // Don't declare success the instant spawn() returns — that only + // confirms fork/exec worked, not that the server ever became + // reachable. Poll for either a successful bind or the child + // exiting early, and report whichever actually happened. + let deadline = Instant::now() + Duration::from_secs(45); + loop { + if tcp_probe(bind, Duration::from_millis(300)) { + std::fs::write(&pid_file, child.id().to_string()).map_err(|e| e.to_string())?; + println!( + "started mfs-server (pid {}) on {bind}; logs: {}", + child.id(), + log_file.display() + ); + break; + } + if let Ok(Some(status)) = child.try_wait() { + let tail = tail_lines(&log_file, 10); + return Err(format!( + "mfs-server exited during startup ({status}); last log lines:\n{tail}" + )); + } + if Instant::now() >= deadline { + return Err(format!( + "mfs-server on {bind} did not become reachable within 45s (pid {} still running); check {}", + child.id(), + log_file.display() + )); + } + std::thread::sleep(Duration::from_millis(500)); + } } ServeAction::Stop => match read_pid(&pid_file) { Some(pid) => { @@ -1560,13 +1683,43 @@ fn serve_cmd(action: &ServeAction) -> Result<(), String> { let _ = std::process::Command::new("kill") .arg(pid.to_string()) .status(); + // `kill` only sends the signal — it doesn't wait for the + // process to actually exit. uvicorn's graceful shutdown + // takes a moment, and Start's new "refuse if the port is + // already occupied" pre-flight check would otherwise race + // against the still-closing old process and mistake it for + // a foreign server, refusing to start the replacement and + // leaving nothing running at all. Wait for either the pid + // to die or the port to free up before proceeding. + let deadline = Instant::now() + Duration::from_secs(10); + while pid_alive(pid) && tcp_probe(bind, Duration::from_millis(200)) { + if Instant::now() >= deadline { + let _ = std::process::Command::new("kill") + .args(["-9", &pid.to_string()]) + .status(); + break; + } + std::thread::sleep(Duration::from_millis(200)); + } let _ = std::fs::remove_file(&pid_file); } return serve_cmd(&ServeAction::Start { bind: bind.clone() }); } - ServeAction::Status => match read_pid(&pid_file) { + ServeAction::Status { bind } => match read_pid(&pid_file) { Some(pid) if pid_alive(pid) => println!("running (pid {pid})"), - _ => println!("not running"), + _ => { + // A dead/missing pidfile only means this CLI isn't tracking + // anything — it doesn't mean nothing is actually serving. + // Probe the port so this doesn't flatly contradict a server + // that's genuinely up but was started outside `mfs serve`. + if tcp_probe(bind, Duration::from_millis(500)) { + println!( + "not running (no pidfile match) — but something IS listening on {bind}, started outside `mfs serve`" + ); + } else { + println!("not running"); + } + } }, ServeAction::Logs => { let s = std::fs::read_to_string(&log_file).unwrap_or_default(); @@ -1585,6 +1738,27 @@ fn serve_cmd(action: &ServeAction) -> Result<(), String> { Ok(()) } +/// Cheap "is anything listening here" check — a plain TCP connect, not an +/// HTTP/health call, so it works before we know if the answering process +/// even speaks HTTP. Used to distinguish a real occupant of `bind` from +/// "nothing here yet" without waiting out a full server startup. +fn tcp_probe(bind: &str, timeout: Duration) -> bool { + let Ok(addr) = bind.parse() else { return false }; + TcpStream::connect_timeout(&addr, timeout).is_ok() +} + +fn tail_lines(path: &PathBuf, n: usize) -> String { + let s = std::fs::read_to_string(path).unwrap_or_default(); + s.lines() + .rev() + .take(n) + .collect::>() + .into_iter() + .rev() + .collect::>() + .join("\n") +} + fn read_pid(p: &PathBuf) -> Option { std::fs::read_to_string(p) .ok() @@ -1644,6 +1818,41 @@ mod tests { std::fs::write(path, content).unwrap(); } + #[test] + fn local_fs_path_from_target_strips_file_uri_forms() { + assert_eq!( + local_fs_path_from_target("file:///tmp/foo"), + Some("/tmp/foo".to_string()) + ); + assert_eq!( + local_fs_path_from_target("file://local/tmp/foo"), + Some("/tmp/foo".to_string()) + ); + assert_eq!(local_fs_path_from_target("/tmp/foo"), None); + assert_eq!(local_fs_path_from_target("file://cid-1/tmp/foo"), None); + assert_eq!(local_fs_path_from_target("postgres://db/x"), None); + } + + #[test] + fn remote_path_normalizes_file_uri_forms_for_local_server() { + let dir = temp_tree("remote-path"); + let dir_str = dir.to_string_lossy().to_string(); + let base = "http://127.0.0.1:13619"; + // bare path passes through unchanged -- the server abspath()s it itself. + assert_eq!(remote_path(base, &dir_str).unwrap(), dir_str); + // both file:// spellings of the SAME local path must resolve identically, + // to the canonical form the server's root_uri is registered under. + let want = format!("file://local{dir_str}"); + assert_eq!( + remote_path(base, &format!("file://{dir_str}")).unwrap(), + want + ); + assert_eq!( + remote_path(base, &format!("file://local{dir_str}")).unwrap(), + want + ); + } + #[test] fn upload_scan_honors_ignore_files() { let root = temp_tree("scan-ignore"); @@ -1765,6 +1974,51 @@ mod tests { ); } + #[test] + fn check_query_length_accepts_value_just_under_the_limit() { + let value = "a".repeat(MAX_QUERY_CHARS); + assert!(check_query_length("query", "query_too_long", &value).is_ok()); + } + + #[test] + fn check_query_length_rejects_oversized_search_query_without_echoing_it() { + let value = "a".repeat(100_000); + let err = check_query_length("query", "query_too_long", &value).unwrap_err(); + + assert_eq!(err.code, "query_too_long"); + assert_eq!(err.detail, "query too long (100000 chars, max 8000)"); + assert!(!err.detail.contains(&value)); + } + + #[test] + fn check_query_length_rejects_oversized_grep_pattern_without_echoing_it() { + let value = "a".repeat(100_000); + let err = check_query_length("pattern", "pattern_too_long", &value).unwrap_err(); + + assert_eq!(err.code, "pattern_too_long"); + assert_eq!(err.detail, "pattern too long (100000 chars, max 8000)"); + assert!(!err.detail.contains(&value)); + } + + #[test] + fn validate_profile_url_accepts_http_and_https() { + assert!(validate_profile_url("http://127.0.0.1:13619").is_ok()); + assert!(validate_profile_url("https://mfs.example.com").is_ok()); + } + + #[test] + fn validate_profile_url_rejects_garbage_and_wrong_scheme() { + for bad in ["not a url at all", "", "ftp://wrong-scheme.example.com"] { + let err = validate_profile_url(bad).unwrap_err(); + assert!( + err.detail + .contains("must be a valid http:// or https:// URL"), + "unexpected error for {bad:?}: {}", + err.detail + ); + } + } + #[test] fn parse_client_cfg_reports_invalid_toml() { let err = diff --git a/docs/search-and-browse.md b/docs/search-and-browse.md index 7bc7e9ef..7116a1bc 100644 --- a/docs/search-and-browse.md +++ b/docs/search-and-browse.md @@ -135,7 +135,9 @@ Grep JSON is smaller: ``` `via` can identify how grep found the match, for example pushdown, BM25, linear -scan, or a notice. +scan, or a notice. A `bm25` match is keyword-based, not an exact-literal or +regex match — expect token-level matching rather than character-for-character +matches on database-backed sources. ## Reopen File-Like Hits diff --git a/server/python/src/mfs_server/api/app.py b/server/python/src/mfs_server/api/app.py index 2414a205..4e4f4da3 100644 --- a/server/python/src/mfs_server/api/app.py +++ b/server/python/src/mfs_server/api/app.py @@ -8,7 +8,7 @@ import asyncio from contextlib import asynccontextmanager -from typing import Literal +from typing import Literal, get_args from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError @@ -21,6 +21,7 @@ from .. import __version__ from ..common.logging import configure_logging from ..config import ServerConfig, load_server_config +from ..connectors.base import ChunkKind from ..engine.engine import Engine from .models import ( AddRequest, @@ -43,6 +44,8 @@ StatusResponse, ) +_VALID_CHUNK_KINDS = get_args(ChunkKind) + # Canonical error codes -> suggested next actions. The endpoints # raise HTTPException with the canonical code as `detail` for these cases; the handler # below turns that into the stable {code, detail, suggestions} envelope SDKs switch on. @@ -54,6 +57,10 @@ "tail_unsupported": ["head", "cat --range"], "locator_not_found": ["re-search; the record may have changed"], "since_unsupported": ["drop --since"], + "config_required": [ + "pass --config with the connector's full config; omitting it would silently drop " + "the existing stored configuration" + ], "sync_already_running": ["mfs job list", "mfs job cancel JOB_ID"], "connector_removing": ["wait for removal to finish, then retry"], "remove_requires_connector_root": [ @@ -477,7 +484,17 @@ async def search( if path: connector_uri, object_prefix = await eng().resolve_connector_uri(path) # comma-separated chunk_kinds, e.g. ?kind=body,directory_summary - chunk_kinds = [k.strip() for k in kind.split(",") if k.strip()] if kind else None + chunk_kinds = None + if kind is not None: + chunk_kinds = [k.strip() for k in kind.split(",")] + invalid = sorted({k for k in chunk_kinds if k not in _VALID_CHUNK_KINDS}) + if invalid: + bad = ", ".join(repr(k) for k in invalid) + raise HTTPException( + 400, + f"unknown chunk kind(s): {bad} -- valid kinds are: " + f"{', '.join(_VALID_CHUNK_KINDS)}", + ) try: results = await eng().search( q, @@ -555,6 +572,13 @@ async def cat( loc = _json.loads(locator) except ValueError: raise HTTPException(400, "invalid locator JSON") + # A syntactically valid JSON value that isn't an object (array, number, + # string, bool, or null) can never match a record's locator_fields -- + # reject it as malformed rather than letting eng().cat() either crash + # on non-dict.get()/`in` or (for null) silently fall through to the + # "no locator given" path, which would hide a real client-side bug. + if not isinstance(loc, dict): + raise HTTPException(400, "invalid locator JSON") try: out = await eng().cat(path, range=rg, meta=meta, density=density, locator=loc) except IsADirectoryError: diff --git a/server/python/src/mfs_server/connectors/base.py b/server/python/src/mfs_server/connectors/base.py index 7caf7772..4acb1c9d 100644 --- a/server/python/src/mfs_server/connectors/base.py +++ b/server/python/src/mfs_server/connectors/base.py @@ -484,8 +484,9 @@ async def healthcheck(self) -> HealthStatus: # only constructs an httpx.AsyncClient at connect() time will say # ok=True even with a 401 token until the first real call. Override # this with a cheap round-trip (GitHub /repos/{o}/{r}, Slack - # auth.test, etc.) when correctness matters for the probe UX. The - # github connector does; the rest currently inherit this default. + # auth.test, S3 head_bucket, etc.) when correctness matters for the + # probe UX. Most connectors do; check a given plugin's own file + # before assuming it still inherits this no-op. return HealthStatus(ok=True) async def introspect_for_wizard(self) -> dict[str, dict]: diff --git a/server/python/src/mfs_server/connectors/file/plugin.py b/server/python/src/mfs_server/connectors/file/plugin.py index e1d73dbd..e2e09ea5 100644 --- a/server/python/src/mfs_server/connectors/file/plugin.py +++ b/server/python/src/mfs_server/connectors/file/plugin.py @@ -444,6 +444,8 @@ async def list(self, path: str) -> list[Entry]: # --- read --- async def read(self, path: str, range: Optional[Range] = None) -> AsyncIterator[bytes]: real = self._real(path) + if not real.exists(): + raise FileNotFoundError(path) if range is None: with open(real, "rb") as f: while chunk := f.read(65536): diff --git a/server/python/src/mfs_server/connectors/s3/plugin.py b/server/python/src/mfs_server/connectors/s3/plugin.py index 67dba4a5..26010735 100644 --- a/server/python/src/mfs_server/connectors/s3/plugin.py +++ b/server/python/src/mfs_server/connectors/s3/plugin.py @@ -21,6 +21,7 @@ Capabilities, ConnectorPlugin, Entry, + HealthStatus, ObjectChange, ObjectKind, PathStat, @@ -73,6 +74,30 @@ def _client_kwargs(self) -> dict: kw["endpoint_url"] = self._cfg("endpoint_url") return kw + async def healthcheck(self) -> HealthStatus: + # The base default never opens a real connection, so a bad access + # key, wrong bucket, or unreachable endpoint would probe clean and + # only surface once a real sync ran and failed. Prefer + # list_objects_v2(MaxKeys=1) over head_bucket: verified against the + # real (currently-broken) test bucket that head_bucket collapses + # both "bad credentials" and "bucket doesn't exist" into an + # undifferentiated 403, while list_objects_v2 surfaces the actual + # error code (e.g. InvalidAccessKeyId) — same cost, better diagnostic. + from botocore.exceptions import BotoCoreError, ClientError + + bucket = self._bucket() + if not bucket: + return HealthStatus(ok=False, detail="no bucket configured") + try: + async with self._session().client("s3", **self._client_kwargs()) as s3: + await s3.list_objects_v2(Bucket=bucket, MaxKeys=1) + except ClientError as e: + code = e.response.get("Error", {}).get("Code", "?") + return HealthStatus(ok=False, detail=f"{bucket}: {code}") + except BotoCoreError as e: + return HealthStatus(ok=False, detail=f"network error reaching {bucket}: {e}") + return HealthStatus(ok=True, detail=f"bucket {bucket} reachable") + def object_kind_of(self, path: str) -> ObjectKind: ext = os.path.splitext(path)[1].lower() if ext in CODE_EXT: diff --git a/server/python/src/mfs_server/connectors/web/plugin.py b/server/python/src/mfs_server/connectors/web/plugin.py index cdc74402..5c5e57b2 100644 --- a/server/python/src/mfs_server/connectors/web/plugin.py +++ b/server/python/src/mfs_server/connectors/web/plugin.py @@ -17,6 +17,7 @@ Capabilities, ConnectorPlugin, Entry, + HealthStatus, ObjectChange, ObjectKind, PathStat, @@ -81,7 +82,49 @@ def url_to_path(url: str) -> str: def _allowed(self, url: str) -> bool: domains = self._cfg("allowed_domains", []) or [] - return urlparse(url).netloc in domains if domains else True + if not domains: + return True + u = urlparse(url) + # An `allowed_domains` entry with a port (e.g. "127.0.0.1:18080") + # must match the URL's exact host:port. An entry with no port (e.g. + # "example.com") matches that host on ANY port, via `.hostname` + # (port-stripped) -- previously every entry was compared against + # the port-inclusive `netloc`, so a bare-host entry never matched a + # non-default-port URL, including the connector's own seed URL. + for d in domains: + if ":" in d: + if u.netloc == d: + return True + elif u.hostname == d: + return True + return False + + async def healthcheck(self) -> HealthStatus: + # The base default never makes a request, so a seed URL excluded by + # allowed_domains (or simply unreachable) would probe clean and only + # surface once a real sync crawled zero pages. Run the same + # _allowed() gate sync() uses, then a cheap GET against one seed URL. + import aiohttp + + start = list(self._cfg("start_urls", []) or []) + if not start: + return HealthStatus(ok=False, detail="no start_urls configured") + blocked = [u for u in start if not self._allowed(u)] + if blocked: + return HealthStatus( + ok=False, + detail=f"allowed_domains excludes {len(blocked)} of {len(start)} start_urls: {blocked[:3]}", + ) + try: + async with aiohttp.ClientSession(headers={"User-Agent": "mfs-web/0.4"}) as sess: + async with sess.get(start[0], timeout=aiohttp.ClientTimeout(total=10)) as resp: + if resp.status >= 400: + return HealthStatus( + ok=False, detail=f"{start[0]} returned HTTP {resp.status}" + ) + except Exception as e: # noqa: BLE001 + return HealthStatus(ok=False, detail=f"network error reaching {start[0]}: {e}") + return HealthStatus(ok=True, detail=f"{len(start)} start_url(s) allowed, seed reachable") def object_kind_of(self, path: str) -> ObjectKind: return "document" if path.endswith(".md") else "directory" @@ -236,4 +279,16 @@ async def sync(self, opts: SyncOptions) -> AsyncIterator[ObjectChange]: if link not in visited: queue.append(link) crawled += 1 + if crawled == 0 and start and not any(self._allowed(u) for u in start): + # Every seed URL was excluded by allowed_domains -- this connector + # can never crawl anything as configured. That's almost never + # intentional (vs. e.g. a since-filter or transient fetch errors + # legitimately yielding zero *new* pages this run), so fail the + # job loudly instead of silently persisting an empty, "succeeded" + # connector -- previously the only way to notice was to realize + # search was returning nothing. + raise ValueError( + "0 pages crawled: every start_url was excluded by allowed_domains " + f"(start_urls={start!r}, allowed_domains={self._cfg('allowed_domains')!r})" + ) await self.state.set("pages", pages) diff --git a/server/python/src/mfs_server/engine/components/connector_factory.py b/server/python/src/mfs_server/engine/components/connector_factory.py index c6145078..09f1ae31 100644 --- a/server/python/src/mfs_server/engine/components/connector_factory.py +++ b/server/python/src/mfs_server/engine/components/connector_factory.py @@ -139,6 +139,35 @@ class CredentialService: _CONN_URI_RE = re.compile(r"[a-zA-Z][a-zA-Z0-9+.\-]*://[^/\s:@]+:[^/\s@]+@") _REDACTED = "" + # --- validate --- + + @classmethod + def validate_no_plaintext_secrets(cls, value: Any, key_path: str = "") -> None: + """Recursively reject inline plaintext secrets BEFORE redact()/persistence. + A caller must supply credentials as env:VAR / file:/abs/path references; a + raw value under a secret-looking key (or an inline user:pass@host + connection string) is rejected outright rather than silently masked, so a + rebuild can never end up resolving a redaction placeholder as if it were a + real credential. Mirrors redact()'s detection surface exactly.""" + if isinstance(value, dict): + for k, v in value.items(): + cls.validate_no_plaintext_secrets(v, f"{key_path}.{k}" if key_path else str(k)) + return + if isinstance(value, list): + for v in value: + cls.validate_no_plaintext_secrets(v, key_path) + return + if not isinstance(value, str) or value in ("", None): + return + if value.startswith(cls._CRED_REF_PREFIXES): + return # a safe credential reference + leaf_key = key_path.rsplit(".", 1)[-1] if key_path else "" + if cls.is_secret_key(leaf_key) or cls._CONN_URI_RE.search(value): + raise ValueError( + f"plaintext secret in config field {key_path!r}: use " + "credential_ref=env:VAR or file:/abs/path, not a literal value" + ) + # --- redact --- @classmethod @@ -150,8 +179,12 @@ def is_secret_key(cls, key: str) -> bool: def redact(cls, value: Any, key_is_secret: bool = False) -> Any: """Recursively redact raw inline secrets from a config before persistence. A credential_ref (env:/secret:/file:/vault:) is kept; anything else under a - secret-looking key is replaced. Recurses into dicts/lists so nested OAuth - token dicts don't leak. Verbatim migration of ``_redact_config``.""" + secret-looking key is replaced with None (never a placeholder string a + plugin could mistake for a real value). Recurses into dicts/lists so nested + OAuth token dicts don't leak. Verbatim migration of ``_redact_config``, + except the placeholder is now None instead of a literal sentinel string — + callers should have already rejected plaintext secrets via + ``validate_no_plaintext_secrets``; this is defense-in-depth only.""" if isinstance(value, dict): return {k: cls.redact(v, cls.is_secret_key(k)) for k, v in value.items()} if isinstance(value, list): @@ -159,11 +192,11 @@ def redact(cls, value: Any, key_is_secret: bool = False) -> Any: if isinstance(value, str) and value.startswith(cls._CRED_REF_PREFIXES): return value # a safe credential reference, keep as-is if key_is_secret and value not in (None, "", [], {}): - return cls._REDACTED + return None # value-level catch: an inline connection string carrying a password leaks # via a field name (dsn/uri/url/connection) that doesn't look secret — redact by shape. if isinstance(value, str) and cls._CONN_URI_RE.search(value): - return cls._REDACTED + return None return value # --- resolve --- @@ -177,6 +210,14 @@ def resolve(value: Any) -> Any: a working ref and silently fail auth. Verbatim migration of ``_resolve_ref``.""" if not isinstance(value, str): return value + if value == CredentialService._REDACTED: + # defense-in-depth: a pre-fix row (or any other path that still produced + # the old sentinel string) must never be resolved as a literal credential. + raise ValueError( + f"credential_ref {value!r}: this is a redaction placeholder, not a " + "real credential — re-register the connector with credential_ref=env:VAR " + "or file:/abs/path" + ) if value.startswith("env:"): name = value[4:] if name not in os.environ: @@ -367,6 +408,12 @@ def resolve_target(self, target: str) -> TargetResolution: # --- credentials (single security entry point) --- + def validate_credentials(self, config: Any) -> None: + """Reject plaintext secrets in a caller-supplied config before it is ever + redacted/persisted. Callers that register/update a connector MUST call + this before ``redact``.""" + self._creds.validate_no_plaintext_secrets(config) + def redact(self, config: Any) -> Any: """Recursively redact inline secrets before persistence. ObjectRepository MUST call this before writing a connectors row.""" diff --git a/server/python/src/mfs_server/engine/engine.py b/server/python/src/mfs_server/engine/engine.py index 586962b3..430cd97a 100644 --- a/server/python/src/mfs_server/engine/engine.py +++ b/server/python/src/mfs_server/engine/engine.py @@ -549,10 +549,19 @@ def _redact_config(cls, value, key_is_secret: bool = False): return CredentialService.redact(value, key_is_secret) async def register_or_get_connector( - self, connector_uri: str, ctype: str, config: dict, overwrite_config: bool = False + self, + connector_uri: str, + ctype: str, + config: dict, + overwrite_config: bool = False, + config_explicit: bool = True, ) -> str: import json + # reject plaintext secrets outright — redact() is defense-in-depth, not the + # gate; a rejected literal here can never round-trip into a stored + # placeholder that a later rebuild mistakes for a real credential. + self.connector_factory.validate_credentials(config) stored = self.connector_factory.redact(config) row = await self.objects.get_connector_id_and_config_by_uri(connector_uri) if row: @@ -567,6 +576,13 @@ async def register_or_get_connector( new_json = json.dumps(stored, sort_keys=True) old_json = row["config_json"] or "{}" drift = _normalize_json(new_json) != _normalize_json(old_json) + if drift and not config_explicit: + # --config was omitted entirely: `config` here is just a URI-derived + # default, not something the caller asked for. Persisting it would + # silently drop the real stored config (credentials, schemas, + # [[objects]] mappings) with zero warning — refuse instead of + # guessing. Nothing has been written yet at this point. + raise ValueError("config_required") if overwrite_config or drift: await self.objects.update_connector_config(row["id"], json.dumps(stored)) if drift and not overwrite_config: @@ -630,7 +646,11 @@ async def add( cfg_dict = {**default_config, **config} if config is not None else default_config existing_connector = await self.objects.get_connector_id_by_uri(connector_uri) cid = await self.register_or_get_connector( - connector_uri, ctype, cfg_dict, overwrite_config=update_config + connector_uri, + ctype, + cfg_dict, + overwrite_config=update_config, + config_explicit=config is not None, ) row0 = await self.objects.get_connector_config_and_status(cid) if row0 and row0["status"] == "removing": @@ -1023,12 +1043,18 @@ async def _finalize_job(self, job_id: str, aborted: str | None) -> None: # --- standalone worker: poll DB queue, process queued jobs --- async def cancel_job(self, job_id: str) -> bool: """Cancel a job: mark it + its pending/running tasks cancelled. A running - worker stops at the next per-object boundary (checked in _run_job).""" + worker stops at the next per-object boundary (checked in _run_job). That + boundary is between objects, not within one -- a single large object already + mid-embed keeps running regardless, so also tell the embed consumer so its + NEXT flush (not the one already in flight) skips this job's chunks instead + of spending real embed time on work that will never be written.""" status = await self.objects.get_job_status(job_id) if not status or status in ("succeeded", "failed", "cancelled"): return False await self.objects.cancel_pending_running_tasks_for_job(job_id) await self.objects.cancel_job_row(job_id) + if self._embed_consumer is not None: + self._embed_consumer.mark_job_cancelled(job_id) return True async def _claim_queued_job(self) -> dict | None: @@ -1719,11 +1745,30 @@ async def resolve_connector_uri(self, target: str) -> tuple[str, str | None]: object_prefix = (connector_uri + rel) if rel not in ("", "/") else None return connector_uri, object_prefix + async def _resolve_readonly_config( + self, connector_uri: str, config: dict | None, default_config: dict + ) -> dict: + """Config resolution shared by probe()/estimate(). When `--config` is + omitted, reuse an already-registered connector's stored config (as + inspect() does) instead of silently falling back to a URI-derived + default — for schemes where the URI alone can't reconstruct real + connection info (postgres/mysql/mongo/s3/web), that default is `{}`, + which produces a connection to nothing meaningful (e.g. postgres + falling through to libpq's OS-user ambient defaults) while still + reporting a real-looking failure, misleading the caller into thinking + their actual registered connector is broken.""" + if config is not None: + return {**default_config, **config} + row = await self.objects.get_connector_id_and_config_by_uri(connector_uri) + if row and row["config_json"]: + return json.loads(row["config_json"]) + return default_config + # --- connector management: probe / inspect / remove --- async def probe(self, target: str, config: dict | None = None) -> dict: """Try-connect a connector without registering or writing state.""" _, connector_uri, ctype, default_config = self._resolve_target(target) - cfg_dict = {**default_config, **config} if config is not None else default_config + cfg_dict = await self._resolve_readonly_config(connector_uri, config, default_config) plugin = None try: # Build inside the guard: _build_plugin resolves credential refs (_resolve_ref), @@ -1762,7 +1807,7 @@ async def estimate( from ..processors.text import chunk_body _, connector_uri, ctype, default_config = self._resolve_target(target) - cfg_dict = {**default_config, **config} if config is not None else default_config + cfg_dict = await self._resolve_readonly_config(connector_uri, config, default_config) tmp_cid = "estimate-" + uuid.uuid4().hex plugin, _ = self._build_plugin(ctype, cfg_dict, tmp_cid) await plugin.connect() @@ -2377,6 +2422,12 @@ async def grep( cid, curi, rel, plugin = await self._open_path(path) scope_prefix = (curi + rel) if rel != "/" else None try: + # _open_path only resolves which connector owns the prefix, not whether + # `rel` exists under it -- unlike ls/cat, nothing downstream fails loudly + # for a missing path (pushdown/BM25/linear-scan all just yield zero + # matches). Stat it explicitly so a bad path 404s like ls/cat instead of + # looking like a real, empty search. + await plugin.stat(rel) results: list[dict] = [] # 2a connector grep pushdown: exact, source-side (e.g. # SQL ILIKE for structured connectors). Returns None when unsupported. diff --git a/server/python/src/mfs_server/engine/pipeline.py b/server/python/src/mfs_server/engine/pipeline.py index 82b120d1..494572d4 100644 --- a/server/python/src/mfs_server/engine/pipeline.py +++ b/server/python/src/mfs_server/engine/pipeline.py @@ -115,8 +115,10 @@ class _Shutdown: _SHUTDOWN = _Shutdown() # Success hook: (task_uri, job_id, chunk_count, partial, error). chunk_count is the total number -# of chunks received for the object; partial is True if any chunk's content was capped, the task -# was truncated (EndOfTask.partial), or a flush dropped some of its chunks. error is None on a +# of chunks actually persisted to Milvus for the object (credited only on a successful flush, +# post-dedup — never on a chunk merely received, and never on a batch that failed and was +# dropped); partial is True if any chunk's content was capped, the task was truncated +# (EndOfTask.partial), or a flush dropped some of its chunks. error is None on a # clean finalize, or ": " when an embed/upsert flush failed for this task — the # callback uses it to record a failed status + last_error. Callers derive search_status from # these without the producer returning them inline (§6.1). @@ -159,6 +161,15 @@ def __init__( # task_id -> ": " recorded when a flush dropped this task's chunks, so # finalize can mark the object failed + attach last_error. self._task_errors: dict[str, str] = {} + # job_ids Engine.cancel_job() has told us about (§ mfs job cancel). A single object's + # chunks can span many flushes; the producer/object-boundary cancellation check + # (Engine._should_stop) only stops the NEXT object from starting, so a large single + # object already mid-flight here would otherwise keep burning embed CPU until it + # finishes regardless of cancellation. Checked once per flush (not per chunk), so this + # can't slow down the common (nothing-cancelled) path. Never explicitly pruned: job ids + # are unique per sync and cancellation is a rare, human-triggered action, so this can't + # meaningfully grow over a process's lifetime. + self._cancelled_jobs: set[str] = set() self._on_succeeded: list[SuccessCallback] = [] self._q: Optional[asyncio.Queue] = None @@ -238,6 +249,16 @@ def on_task_retry(self, task_id: str) -> None: self._meta.pop(task_id, None) self._task_errors.pop(task_id, None) + # --- cancellation --- + def mark_job_cancelled(self, job_id: str) -> None: + """Record a job as cancelled (called from Engine.cancel_job()) so the NEXT flush + skips embedding its already-queued chunks instead of spending real embed-API/CPU + time on work whose result will never be written. This can't make an in-flight + batch_embed() call return early -- that one batch still runs to completion -- but + it stops every batch after it, which is the realistic bound on how fast `mfs job + cancel` can actually take effect for a single large object still mid-flight.""" + self._cancelled_jobs.add(job_id) + # --- consume --- async def _consume(self, env: TaskEnvelope) -> None: payload = env.payload @@ -256,9 +277,6 @@ async def _handle_chunk(self, env: TaskEnvelope, chunk: Chunk) -> None: await self._milvus.delete_by_object(env.connector_uri, env.task_uri) self._deleted.add(tid) self._pending[tid] = self._pending.get(tid, 0) + 1 - self._count[tid] = ( - self._count.get(tid, 0) + 1 - ) # cumulative: every received chunk is written if chunk.partial: self._partial[tid] = True self._batch.append((env, chunk)) @@ -284,8 +302,22 @@ async def _flush(self) -> None: if not self._batch: return # No concurrent appender can grow it mid-flush: _flush only runs inside the single - # run() loop, so this snapshot IS the whole pending batch. + # run() loop, so this snapshot IS the whole pending batch. Claiming it (rather than + # just aliasing it) up front is safe for the same reason: nothing else touches + # self._batch while this coroutine is suspended on an await below. batch = self._batch + self._batch = [] + if self._cancelled_jobs: + cancelled = [item for item in batch if item[0].job_id in self._cancelled_jobs] + if cancelled: + batch = [item for item in batch if item[0].job_id not in self._cancelled_jobs] + # Reuse the existing failure path rather than inventing new bookkeeping -- + # cancellation was already called out as one of _fail_batch's intended + # triggers ("an OOM embed, a Milvus error, a cancellation") when that method + # was written; this is the first caller to actually exercise it that way. + await self._fail_batch(cancelled, RuntimeError("job_cancelled")) + if not batch: + return try: # 1. tx_cache lookup for vectors, embed only the misses (§6.3), cache them back. keys = [self._cache_key(ch) for _, ch in batch] @@ -301,7 +333,23 @@ async def _flush(self) -> None: # 2. one upsert for the whole (cross-task, cross-kind) batch — idempotent by # chunk_id PK (§5.3 / §6.2). delete_by_object already ran per task on receipt. - rows = [self._build_row(env, ch, cached[keys[i]]) for i, (env, ch) in enumerate(batch)] + built = [ + (env.task_id, self._build_row(env, ch, cached[keys[i]])) + for i, (env, ch) in enumerate(batch) + ] + # de-dupe by chunk_id (last-occurrence wins) so two chunks colliding on the same + # locator-derived id within one batch don't make Milvus reject the WHOLE batch — + # rows without a chunk_id (the base class's default _build_row, used directly in + # unit tests) skip this, since there is nothing to collide on. + if built and "chunk_id" in built[0][1]: + deduped: dict[str, tuple[str, dict]] = {} + for tid, row in built: + deduped[row["chunk_id"]] = (tid, row) + built = list(deduped.values()) + rows = [row for _, row in built] + # raw per-task counts across the WHOLE flush attempt (every chunk processed, dedup + # or not) — pending tracks "no longer queued for a future flush", true regardless + # of whether the row survived dedup or the upsert ultimately succeeds. counts: dict[str, int] = {} for env, _ in batch: counts[env.task_id] = counts.get(env.task_id, 0) + 1 @@ -315,8 +363,15 @@ async def _flush(self) -> None: await self._fail_batch(batch, exc) return - # 3. write acked: clear the batch, then decrement pending and finalize. + # 3. write acked: clear the batch, credit chunk_count for what actually persisted + # (the deduped rows — not the raw per-task counts above), then decrement pending and + # finalize. self._batch = [] + written: dict[str, int] = {} + for tid, _ in built: + written[tid] = written.get(tid, 0) + 1 + for tid, n in written.items(): + self._count[tid] = self._count.get(tid, 0) + n for tid, n in counts.items(): self._pending[tid] = self._pending.get(tid, 0) - n await self._maybe_finalize(tid) diff --git a/server/python/src/mfs_server/server/__main__.py b/server/python/src/mfs_server/server/__main__.py index ff7565db..c1e8b555 100644 --- a/server/python/src/mfs_server/server/__main__.py +++ b/server/python/src/mfs_server/server/__main__.py @@ -101,6 +101,15 @@ def main(argv: list[str] | None = None) -> int: cfg = load_server_config(args.config) configure_logging() + from .. import __version__ + + # The startup log previously had no version string anywhere, so + # confirming which build is actually running required cross-checking + # `ps aux` against known install paths. This is semver only (no git + # commit — mfs-server's package build doesn't stamp one the way the + # Rust CLI's build.rs now does for `mfs --version`), but it is at + # least visible without leaving the log. + logger.info("mfs-server %s starting", __version__) _ensure_auth_token(cfg) host, _, port = args.bind.partition(":") app = create_app(cfg, preload_local_models=True) diff --git a/server/python/tests/test_api_auth.py b/server/python/tests/test_api_auth.py index 68f109b2..0698b006 100644 --- a/server/python/tests/test_api_auth.py +++ b/server/python/tests/test_api_auth.py @@ -119,6 +119,69 @@ def test_search_rejects_unknown_query_params(tmp_path) -> None: } +def test_search_rejects_unknown_kind(tmp_path) -> None: + cfg = ServerConfig(home=str(tmp_path), auth_token="expected").resolve_defaults() + app = create_app(cfg) + client = TestClient(app) + + response = client.get( + "/v1/search?q=needle&kind=totally_bogus_kind", + headers={"Authorization": "Bearer expected"}, + ) + + assert response.status_code == 400 + body = response.json() + assert body["code"] == "bad_request" + assert "totally_bogus_kind" in body["detail"] + assert "body" in body["detail"] # valid kinds listed for guidance + + +def test_search_rejects_empty_string_kind(tmp_path) -> None: + cfg = ServerConfig(home=str(tmp_path), auth_token="expected").resolve_defaults() + app = create_app(cfg) + client = TestClient(app) + + response = client.get( + "/v1/search?q=needle&kind=", + headers={"Authorization": "Bearer expected"}, + ) + + assert response.status_code == 400 + assert response.json()["code"] == "bad_request" + + +def test_search_rejects_valid_kind_with_stray_empty_segment(tmp_path) -> None: + cfg = ServerConfig(home=str(tmp_path), auth_token="expected").resolve_defaults() + app = create_app(cfg) + client = TestClient(app) + + response = client.get( + "/v1/search?q=needle&kind=body,", + headers={"Authorization": "Bearer expected"}, + ) + + assert response.status_code == 400 + assert response.json()["code"] == "bad_request" + + +def test_search_accepts_valid_kinds(tmp_path) -> None: + cfg = ServerConfig(home=str(tmp_path), auth_token="expected").resolve_defaults() + app = create_app(cfg) + + with TestClient(app) as client: + client.headers["Authorization"] = "Bearer expected" + + single = client.get("/v1/search", params={"q": "needle", "kind": "body"}) + multiple = client.get( + "/v1/search", params={"q": "needle", "kind": "body,directory_summary"} + ) + omitted = client.get("/v1/search", params={"q": "needle"}) + + for response in (single, multiple, omitted): + assert response.status_code == 200 + assert response.json() == {"results": []} + + def test_framework_http_errors_use_documented_envelope(tmp_path) -> None: cfg = ServerConfig(home=str(tmp_path), auth_token="expected").resolve_defaults() app = create_app(cfg) diff --git a/server/python/tests/test_cat_locator.py b/server/python/tests/test_cat_locator.py new file mode 100644 index 00000000..a2514258 --- /dev/null +++ b/server/python/tests/test_cat_locator.py @@ -0,0 +1,139 @@ +"""GET /v1/cat?locator=... : the decoded locator must be a JSON object. + +A non-dict locator (array, number, string, bool, null) can never satisfy +_locator_matches's `k in locator` / `locator.get(k)` checks -- historically +that either raised an unhandled 500 (list/int/etc.) or, for `null`, silently +fell through to "no locator given" instead of erroring. Both are wrong: the +client asked for a specific record and typo'd the shape, it should get a +clean 400, not a crash or a different answer. +""" + +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from mfs_server.config import ServerConfig +from mfs_server.connectors.base import ObjectConfig, PathStat +from mfs_server.engine.engine import Engine + + +def _app_client(tmp_path): + cfg = ServerConfig(home=str(tmp_path), auth_token="expected").resolve_defaults() + from mfs_server.api.app import create_app + + app = create_app(cfg) + client = TestClient(app) + client.headers["Authorization"] = "Bearer expected" + return client + + +@pytest.mark.parametrize( + "raw_locator", + ['["id","ord_1001"]', "42", "null", "true", '"just a string"'], +) +def test_cat_rejects_non_object_locator(tmp_path, raw_locator: str) -> None: + client = _app_client(tmp_path) + + resp = client.get("/v1/cat", params={"path": "postgres://db/orders", "locator": raw_locator}) + + assert resp.status_code == 400 + body = resp.json() + assert body["code"] == "bad_request" + assert body["detail"] == "invalid locator JSON" + + +def test_cat_rejects_malformed_locator_json_syntax(tmp_path) -> None: + client = _app_client(tmp_path) + + resp = client.get( + "/v1/cat", params={"path": "postgres://db/orders", "locator": "{not valid json"} + ) + + assert resp.status_code == 400 + assert resp.json()["detail"] == "invalid locator JSON" + + +# --- engine-level regression coverage: valid-object locators are unaffected --- + +_OCFG = ObjectConfig(text_fields=["title"], locator_fields=["id"]) + + +class _FakeConnCtx: + def object_config_for(self, path): + return _OCFG + + +class _FakeStructuredPlugin: + def __init__(self, records: list[dict]): + self._records = records + self.ctx = _FakeConnCtx() + self.closed = False + + async def stat(self, rel): + return PathStat( + path=rel, + type="file", + media_type="application/x-collection", + size_hint=1, + fingerprint="fp:" + rel, + ) + + def object_kind_of(self, rel): + return "table_rows" + + def read_records(self, rel, range=None): + recs = self._records + + async def gen(): + for r in recs: + yield r + + return gen() + + async def close(self) -> None: + self.closed = True + + +async def _build_engine(tmp_path) -> Engine: + cfg = ServerConfig() + cfg.metadata.backend = "sqlite" + cfg.metadata.path = str(tmp_path / "meta.db") + cfg.transformation_cache.backend = "sqlite" + cfg.transformation_cache.db_path = str(tmp_path / "tx.db") + cfg.artifact_cache.root = str(tmp_path / "art") + eng = Engine(cfg) + await eng.meta.connect() + await eng.meta.init_schema() + return eng + + +async def test_cat_with_matching_dict_locator_returns_the_record(tmp_path) -> None: + eng = await _build_engine(tmp_path) + plugin = _FakeStructuredPlugin([{"id": "ord_1001", "title": "widget"}]) + + async def fake_open_path(path: str): + return "cid", "postgres://db", "/orders", plugin + + eng._open_path = fake_open_path # type: ignore[method-assign] + + out = await eng.cat("postgres://db/orders", locator={"id": "ord_1001"}) + + assert out["locator"] == {"id": "ord_1001"} + assert "ord_1001" in out["content"] + await eng.meta.close() + + +async def test_cat_with_dict_locator_no_match_raises_locator_not_found(tmp_path) -> None: + eng = await _build_engine(tmp_path) + plugin = _FakeStructuredPlugin([{"id": "ord_1001", "title": "widget"}]) + + async def fake_open_path(path: str): + return "cid", "postgres://db", "/orders", plugin + + eng._open_path = fake_open_path # type: ignore[method-assign] + + with pytest.raises(ValueError, match="locator_not_found"): + await eng.cat("postgres://db/orders", locator={"id": "does_not_exist"}) + + await eng.meta.close() diff --git a/server/python/tests/test_connector_factory.py b/server/python/tests/test_connector_factory.py index d4d94421..59a696d9 100644 --- a/server/python/tests/test_connector_factory.py +++ b/server/python/tests/test_connector_factory.py @@ -178,16 +178,16 @@ class TestRedact: def test_top_level_secret_keys_redacted(self): for k in ("password", "api_key", "refresh_token", "access_key", "client_secret"): out = CredentialService.redact({k: "shh"}) - assert out[k] == CredentialService._REDACTED, k + assert out[k] is None, k def test_nested_dict_recursive(self): out = CredentialService.redact({"oauth": {"access_token": "t", "scope": "read"}}) - assert out["oauth"]["access_token"] == CredentialService._REDACTED + assert out["oauth"]["access_token"] is None assert out["oauth"]["scope"] == "read" def test_list_recursive(self): out = CredentialService.redact({"tokens": [{"token": "t"}, {"id": 1}]}) - assert out["tokens"][0]["token"] == CredentialService._REDACTED + assert out["tokens"][0]["token"] is None assert out["tokens"][1]["id"] == 1 def test_env_file_refs_preserved(self): @@ -201,7 +201,7 @@ def test_non_secret_plaintext_kept(self): def test_value_level_connection_string_redacted(self): out = CredentialService.redact({"url": "postgres://u:p@host/db"}) - assert out["url"] == CredentialService._REDACTED + assert out["url"] is None def test_plain_url_not_redacted(self): out = CredentialService.redact({"url": "https://example.com/path"}) @@ -220,9 +220,9 @@ def test_is_secret_key_case_insensitive(self): def test_unimplemented_scheme_under_secret_key_redacted(self): out = CredentialService.redact({"password": "secret:foo"}) - assert out["password"] == CredentialService._REDACTED + assert out["password"] is None out2 = CredentialService.redact({"token": "vault:bar"}) - assert out2["token"] == CredentialService._REDACTED + assert out2["token"] is None # =========================================================================== @@ -263,6 +263,39 @@ def test_non_string_passthrough(self): def test_no_prefix_passthrough(self): assert CredentialService.resolve("plain-value") == "plain-value" + def test_redaction_placeholder_rejected(self): + with pytest.raises(ValueError, match="redaction placeholder"): + CredentialService.resolve(CredentialService._REDACTED) + + +# =========================================================================== +# §7.1b CredentialService.validate_no_plaintext_secrets +# =========================================================================== + + +class TestValidateNoPlaintextSecrets: + def test_plaintext_secret_key_rejected(self): + with pytest.raises(ValueError, match="plaintext secret"): + CredentialService.validate_no_plaintext_secrets({"password": "shh"}) + + def test_plaintext_connection_string_rejected(self): + with pytest.raises(ValueError, match="plaintext secret"): + CredentialService.validate_no_plaintext_secrets({"uri": "postgres://u:p@host/db"}) + + def test_nested_plaintext_secret_rejected(self): + with pytest.raises(ValueError, match="plaintext secret"): + CredentialService.validate_no_plaintext_secrets({"oauth": {"access_token": "t"}}) + + def test_env_and_file_refs_accepted(self): + CredentialService.validate_no_plaintext_secrets({"password": "env:VAR", "key": "file:/p"}) + + def test_non_secret_plaintext_accepted(self): + CredentialService.validate_no_plaintext_secrets({"host": "db.example.com", "port": 5432}) + + def test_empty_values_accepted(self): + for v in (None, "", [], {}): + CredentialService.validate_no_plaintext_secrets({"password": v}) + # =========================================================================== # §7.4 PluginBuilder.build diff --git a/server/python/tests/test_engine_connector_lifecycle.py b/server/python/tests/test_engine_connector_lifecycle.py index 9e6395fa..771a7b0f 100644 --- a/server/python/tests/test_engine_connector_lifecycle.py +++ b/server/python/tests/test_engine_connector_lifecycle.py @@ -22,6 +22,9 @@ def on_sync_done(self, *args, **kwargs): def evict_job(self, *args, **kwargs): return None + def on_yield_object_change(self, *args, **kwargs): + return None + async def _build_engine(tmp_path) -> Engine: load_builtin() @@ -80,3 +83,124 @@ async def test_failed_initial_add_rolls_back_connector_registration(tmp_path): assert row["n"] == 0 finally: await eng.meta.close() + + +# ---------------------------------------------------------------------- +# add()/register_or_get_connector — omitted --config on an already-registered +# connector must never silently persist a drifted, URI-derived default over the +# real stored config. Previously, `mfs connector update ` with no +# --config would silently and permanently wipe the connector's stored +# credentials whenever the URI alone couldn't reconstruct them (postgres, +# mysql, mongo, s3, web). +# ---------------------------------------------------------------------- + + +async def _seed_file_root(tmp_path, name="repo"): + root = tmp_path / name + root.mkdir(parents=True, exist_ok=True) + (root / "a.md").write_text("hello") + return root + + +async def test_add_without_config_on_drifted_connector_is_rejected(tmp_path): + eng = await _build_engine(tmp_path) + try: + root = await _seed_file_root(tmp_path) + target = str(root) + + # initial registration with an explicit config whose client_id differs from + # what derive_target would reconstruct from the bare URI alone ("local") — + # this is the "real stored config" a later bare re-sync must not clobber. + await eng.add(target, config={"client_id": "custom"}, process=False) + row = await eng.objects.get_connector_id_and_config_by_uri(f"file://local{root}") + before = row["config_json"] + assert '"custom"' in before + + with pytest.raises(ValueError, match="config_required"): + await eng.add(target, config=None, process=False) + + after = (await eng.objects.get_connector_id_and_config_by_uri(f"file://local{root}"))[ + "config_json" + ] + assert after == before + finally: + await eng.meta.close() + + +async def test_connector_update_without_config_on_drifted_connector_is_rejected(tmp_path): + """Same as above but through the `mfs connector update` path (update_config=True).""" + eng = await _build_engine(tmp_path) + try: + root = await _seed_file_root(tmp_path) + target = str(root) + + await eng.add(target, config={"client_id": "custom"}, process=False) + row = await eng.objects.get_connector_id_and_config_by_uri(f"file://local{root}") + before = row["config_json"] + + with pytest.raises(ValueError, match="config_required"): + await eng.add(target, config=None, update_config=True, process=False) + + after = (await eng.objects.get_connector_id_and_config_by_uri(f"file://local{root}"))[ + "config_json" + ] + assert after == before + finally: + await eng.meta.close() + + +async def test_add_with_explicit_differing_config_still_persists(tmp_path): + """No regression: an actual --config that differs from the stored one must still + persist, unaffected by the new config_explicit guard.""" + eng = await _build_engine(tmp_path) + try: + root = await _seed_file_root(tmp_path) + target = str(root) + + job_id = await eng.add(target, config={"client_id": "custom"}, process=False) + await eng.cancel_job(job_id) # free the one-in-flight-sync slot for the 2nd add() + await eng.add(target, config={"client_id": "custom2"}, process=False) + + row = await eng.objects.get_connector_id_and_config_by_uri(f"file://local{root}") + assert '"custom2"' in row["config_json"] + finally: + await eng.meta.close() + + +async def test_add_without_config_on_undrifted_connector_is_a_silent_noop(tmp_path): + """No regression: when the URI-derived default happens to exactly match the + stored config (the minimal file:// case), a bare re-sync stays a safe no-op.""" + eng = await _build_engine(tmp_path) + try: + root = await _seed_file_root(tmp_path) + target = str(root) + + job_id = await eng.add(target, config=None, process=False) + row = await eng.objects.get_connector_id_and_config_by_uri(f"file://local{root}") + before = row["config_json"] + + await eng.cancel_job(job_id) # free the one-in-flight-sync slot for the 2nd add() + await eng.add(target, config=None, process=False) + + after = (await eng.objects.get_connector_id_and_config_by_uri(f"file://local{root}"))[ + "config_json" + ] + assert after == before + finally: + await eng.meta.close() + + +async def test_add_without_config_on_brand_new_connector_is_unaffected(tmp_path): + """Brand-new registration (no existing row) never hits the config_explicit guard.""" + eng = await _build_engine(tmp_path) + try: + root = await _seed_file_root(tmp_path) + target = str(root) + + job_id = await eng.add(target, config=None, process=False) + assert job_id + + row = await eng.objects.get_connector_id_and_config_by_uri(f"file://local{root}") + assert row is not None + finally: + await eng.meta.close() diff --git a/server/python/tests/test_file_connector_read_not_found.py b/server/python/tests/test_file_connector_read_not_found.py new file mode 100644 index 00000000..a52f283f --- /dev/null +++ b/server/python/tests/test_file_connector_read_not_found.py @@ -0,0 +1,59 @@ +"""read() on a nonexistent path (used by mfs head/tail/cat) must raise a plain +FileNotFoundError carrying only the connector-relative path -- not the raw OS +error, whose str() bakes in the absolute local filesystem path. stat()/list()/ +grep() in this same file already guard against that leak; read() must match. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from mfs_server.connectors.base import ConnectorContext +from mfs_server.connectors.file.plugin import FileConfig, FilePlugin + + +class MemoryState: + async def get(self, key: str) -> Any | None: + return None + + async def set(self, key: str, value: Any) -> None: + return None + + async def delete(self, key: str) -> None: + return None + + async def checkpoint(self) -> None: + return None + + +def _plugin(root) -> FilePlugin: + ctx = ConnectorContext(MemoryState(), "connector-id", "namespace-id") + return FilePlugin(FileConfig(root=str(root)), None, ctx=ctx) + + +@pytest.mark.anyio +async def test_file_read_missing_path_raises_relative_not_found(tmp_path): + plugin = _plugin(tmp_path) + + with pytest.raises(FileNotFoundError) as excinfo: + async for _ in plugin.read("/does/not/exist.txt"): + pass + + assert str(excinfo.value) == "/does/not/exist.txt" + assert str(tmp_path) not in str(excinfo.value) + + +@pytest.mark.anyio +async def test_file_read_missing_path_with_range_raises_relative_not_found(tmp_path): + from mfs_server.connectors.base import Range + + plugin = _plugin(tmp_path) + + with pytest.raises(FileNotFoundError) as excinfo: + async for _ in plugin.read("/does/not/exist.txt", range=Range(start=0, end=5)): + pass + + assert str(excinfo.value) == "/does/not/exist.txt" + assert str(tmp_path) not in str(excinfo.value) diff --git a/server/python/tests/test_grep_not_found.py b/server/python/tests/test_grep_not_found.py new file mode 100644 index 00000000..dcc56c98 --- /dev/null +++ b/server/python/tests/test_grep_not_found.py @@ -0,0 +1,62 @@ +"""GET /v1/grep on a path that doesn't exist under its connector must 404, like +ls/cat -- not silently return zero results. Unlike ls/cat, grep's pushdown/BM25/ +linear-scan dispatch never touches the target path directly, so a missing path +just looks like a real search with no matches unless the engine checks existence +up front. +""" + +from __future__ import annotations + +import pytest + +from mfs_server.config import ServerConfig +from mfs_server.engine.engine import Engine + + +class _FakePlugin: + """Mimics FilePlugin.stat: raises FileNotFoundError for anything but /exists.txt.""" + + def __init__(self): + self.closed = False + + async def stat(self, rel): + if rel != "/exists.txt": + raise FileNotFoundError(rel) + from mfs_server.connectors.base import PathStat + + return PathStat(path=rel, type="file", media_type="text/plain", size_hint=1) + + async def grep(self, pattern, rel, options): + return None # no pushdown -> engine falls through to BM25/linear scan + + async def close(self) -> None: + self.closed = True + + +async def _build_engine(tmp_path) -> Engine: + cfg = ServerConfig() + cfg.metadata.backend = "sqlite" + cfg.metadata.path = str(tmp_path / "meta.db") + cfg.transformation_cache.backend = "sqlite" + cfg.transformation_cache.db_path = str(tmp_path / "tx.db") + cfg.artifact_cache.root = str(tmp_path / "art") + eng = Engine(cfg) + await eng.meta.connect() + await eng.meta.init_schema() + return eng + + +async def test_grep_missing_path_raises_file_not_found(tmp_path) -> None: + eng = await _build_engine(tmp_path) + plugin = _FakePlugin() + + async def fake_open_path(path: str): + return "cid", "file://local/root", "/does/not/exist", plugin + + eng._open_path = fake_open_path # type: ignore[method-assign] + + with pytest.raises(FileNotFoundError): + await eng.grep("needle", "file://local/root/does/not/exist") + + assert plugin.closed + await eng.meta.close() diff --git a/server/python/tests/test_pipeline.py b/server/python/tests/test_pipeline.py index 494fae25..49163a69 100644 --- a/server/python/tests/test_pipeline.py +++ b/server/python/tests/test_pipeline.py @@ -330,7 +330,7 @@ async def test_upsert_failure_drops_batch_and_finalizes_failed(): c, embedder, milvus, tx = _consumer(batch_size=10, idle_ms=30) finals: list[tuple] = [] c.register_on_succeeded( - lambda uri, job, count, partial, error: finals.append((uri, partial, error)) + lambda uri, job, count, partial, error: finals.append((uri, count, partial, error)) ) milvus.upsert = AsyncMock(side_effect=RuntimeError("milvus write error")) q = make_chunks_q(10) @@ -342,11 +342,89 @@ async def test_upsert_failure_drops_batch_and_finalizes_failed(): await c.shutdown() assert len(finals) == 1 - uri, partial, error = finals[0] - assert uri == "c://x/T" and partial is True and "RuntimeError" in error + uri, count, partial, error = finals[0] + # Bug C: chunk_count must reflect what actually persisted (nothing — the whole batch was + # dropped), never the attempted count. + assert uri == "c://x/T" and count == 0 and partial is True and "RuntimeError" in error assert c._pending == {} and c._task_errors == {} +async def test_duplicate_chunk_id_within_batch_dedupes_and_both_tasks_succeed(): + # Bug B: two chunks (from different tasks, but landing in the same flush) that hash to the + # same chunk_id must NOT crash the whole batch — dedupe by chunk_id (last-write-wins) before + # upsert, and both tasks still finalize successfully with the deduped count. + class _DedupingConsumer(EmbedConsumer): + def _build_row(self, env, chunk, vec): + row = super()._build_row(env, chunk, vec) + row["chunk_id"] = f"cid-{chunk.locator}" # deliberately collide via shared locator + return row + + embedder, milvus, tx = _mocks() + c = _DedupingConsumer(embedder, milvus, tx, batch_size=2, idle_ms=5000) + finals: dict[str, tuple] = {} + c.register_on_succeeded( + lambda uri, job, count, partial, error=None: finals.update({uri: (count, partial, error)}) + ) + q = make_chunks_q(2) + c.start(q) + + # A and B each contribute one chunk with the SAME locator -> same chunk_id, same flush. + await q.put(_chunk_env("A", "content-a", locator="shared")) + await q.put(_chunk_env("B", "content-b", locator="shared")) # triggers flush at batch_size=2 + await q.put(_eot_env("A")) + await q.put(_eot_env("B")) + await c.shutdown() + + assert milvus.upsert.call_count == 1 + rows = milvus.upsert.call_args.args[0] + assert len(rows) == 1 # deduped down to one row + assert rows[0]["content"] == "content-b" # last occurrence wins (B was appended after A) + assert rows[0]["chunk_id"] == "cid-shared" + + # both tasks finalize successfully (no error), not failed — the batch was never dropped. + assert finals["c://x/A"][2] is None + assert finals["c://x/B"][2] is None + # only the single surviving (deduped) row is credited, and it was B's (last-write-wins). + assert finals["c://x/A"] == (0, False, None) + assert finals["c://x/B"] == (1, False, None) + + +async def test_partial_success_then_failure_reports_only_successful_chunks(): + # The most important Bug C case: a task spanning two flushes where the first succeeds (2 + # chunks written) and the second fails (1 more chunk attempted) must finalize with + # chunk_count == 2 (only what actually persisted), not 3 (the attempted total). + c, embedder, milvus, tx = _consumer(batch_size=2, idle_ms=5000) + finals: dict[str, tuple] = {} + c.register_on_succeeded( + lambda uri, job, count, partial, error=None: finals.update({uri: (count, partial, error)}) + ) + + upsert_calls = {"n": 0} + + async def flaky_upsert(rows): + upsert_calls["n"] += 1 + if upsert_calls["n"] == 2: + raise RuntimeError("milvus write error") + + milvus.upsert = AsyncMock(side_effect=flaky_upsert) + q = make_chunks_q(2) + c.start(q) + + # first flush (2 chunks) succeeds. + await q.put(_chunk_env("T", "a")) + await q.put(_chunk_env("T", "b")) + # second flush (1 more chunk, below batch_size) fails via the idle-timeout drain. + await q.put(_chunk_env("T", "c")) + await q.put(_eot_env("T")) + await c.shutdown() + + assert upsert_calls["n"] == 2 + count, partial, error = finals["c://x/T"] + assert count == 2 # only the first, successful flush's chunks + assert partial is True # flagged partial by the dropped second batch + assert error is not None and "RuntimeError" in error + + async def test_on_task_retry_resets_per_task_state(): # findings (8)/(9): a producer that raised after pumping 2 chunks (no EndOfTask) left # _deleted/_pending/_count behind. on_task_retry must clear them so the retry re-deletes @@ -378,3 +456,54 @@ async def test_on_task_retry_resets_per_task_state(): assert seen["c://x/T"] == (3, False) # counts ONLY the retry's chunks, not 2+3 assert milvus.delete_by_object.await_count == 2 # delete ran again on the retry's first chunk assert c._pending == {} + + +async def test_mark_job_cancelled_skips_future_flush_without_embedding(): + # `mfs job cancel` can't interrupt a batch_embed() call already in flight, but the NEXT + # flush for that job's chunks should skip the embed/upsert entirely (via the same + # _fail_batch path a real embed/Milvus failure uses) instead of burning embed time on + # work that will never be written -- and finalize the task failed, not wedge it. + c, embedder, milvus, tx = _consumer(batch_size=10, idle_ms=30) + finals: dict[str, tuple] = {} + c.register_on_succeeded( + lambda uri, job, count, partial, error: finals.update({uri: (count, partial, error)}) + ) + q = make_chunks_q(10) + c.start(q) + + c.mark_job_cancelled("job1") # T's job (default job="job1" in _chunk_env/_eot_env) + await q.put(_chunk_env("T", "a")) + await q.put(_chunk_env("T", "b")) + await q.put(_eot_env("T")) + await c.shutdown() + + embedder.batch_embed.assert_not_awaited() # never even tried to embed a cancelled job's chunks + milvus.upsert.assert_not_awaited() + count, partial, error = finals["c://x/T"] + assert count == 0 and partial is True and "job_cancelled" in error + assert c._pending == {} and c._task_errors == {} # released, not leaked + + +async def test_mark_job_cancelled_only_affects_that_job_not_others_in_same_batch(): + # Two tasks from different jobs land in the SAME flush; cancelling one job must not + # affect the other's chunks, which should embed/upsert/finalize normally. + c, embedder, milvus, tx = _consumer(batch_size=10, idle_ms=30) + finals: dict[str, tuple] = {} + c.register_on_succeeded( + lambda uri, job, count, partial, error: finals.update({uri: (job, count, partial, error)}) + ) + q = make_chunks_q(10) + c.start(q) + + c.mark_job_cancelled("job-cancelled") + await q.put(_chunk_env("T1", "a", job="job-cancelled")) + await q.put(_eot_env("T1", job="job-cancelled")) + await q.put(_chunk_env("T2", "b", job="job-active")) + await q.put(_eot_env("T2", job="job-active")) + await c.shutdown() + + assert finals["c://x/T1"] == ("job-cancelled", 0, True, "RuntimeError: job_cancelled") + assert finals["c://x/T2"] == ("job-active", 1, False, None) + milvus.upsert.assert_awaited_once() # only T2's row was ever written + (written_rows,), _ = milvus.upsert.call_args + assert len(written_rows) == 1 # T1's chunk never reached the embed/upsert call at all