Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 92 additions & 3 deletions hindsight-api-slim/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4205,7 +4205,10 @@ async def _search_with_retries(
if tracer:
tracer.start()

backend_acquire_start = time.time()
backend = await self._get_read_backend()
if tracer:
tracer.add_phase_metric("backend_acquisition", time.time() - backend_acquire_start)
recall_start = time.time()

# Buffer logs for clean output in concurrent scenarios.
Expand Down Expand Up @@ -4501,10 +4504,12 @@ def to_tuple_format(results):
},
)
# Also expose each retrieval method as its own phase so
# benchmarks can pinpoint which sub-query drives latency.
# benchmarks can pinpoint which sub-query drives latency. These are
# children of parallel_retrieval (marked diagnostic so the phase-coverage
# check doesn't double-count them).
for _method, _dur in aggregated_timings.items():
if _dur > 0:
tracer.add_phase_metric(f"retrieval_{_method}", _dur)
tracer.add_phase_metric(f"retrieval_{_method}", _dur, {"diagnostic": True})

# Step 3: Merge ranked lists. RRF by default; interleave (round-robin) when
# requested by consolidation dedup recall — RRF averages a strong-in-one-arm
Expand Down Expand Up @@ -4620,6 +4625,12 @@ def to_tuple_format(results):
# is_passthrough_reranker tells the scoring code to seed CE scores
# from RRF rank — only meaningful when the configured reranker is
# the slim/passthrough one that returns a constant score per pair.
#
# Timed separately from "reranking": the cross-encoder duration above
# (step_duration) is captured before this block runs, so the scoring
# math, additive boosts and final sort would otherwise be invisible in
# the phase metrics (issue #2361).
scoring_start = time.time()
if scored_results and reranking == "interleave":
# Interleave order is authoritative for dedup recall: do NOT re-sort by the
# recency/temporal boosts — that re-sort is precisely what buried the twin
Expand Down Expand Up @@ -4667,6 +4678,13 @@ def to_tuple_format(results):
step_duration,
{"reranker_type": rerank_kind, "candidates_reranked": len(scored_results)},
)
# Combined scoring + additive boosts + final sort, plus the trace
# serialization of reranked entries done just above.
tracer.add_phase_metric(
"combined_scoring",
time.time() - scoring_start,
{"candidates_scored": len(scored_results)},
)

# Cancellation checkpoint: reranking is done; skip the remaining
# enrichment (chunk/entity/source-fact fetches, each its own DB work)
Expand All @@ -4691,6 +4709,7 @@ def to_tuple_format(results):
if sr.retrieval.fact_type == "observation"
]
if observation_ids:
dedup_start = time.time()
superseded_ids: set[str] = set()
async with acquire_with_retry(backend) as dedup_conn:
obs_rows = await dedup_conn.fetch(
Expand All @@ -4701,6 +4720,12 @@ def to_tuple_format(results):
""",
observation_ids,
)
if tracer:
tracer.add_phase_metric(
"prefer_observations_dedup",
time.time() - dedup_start,
{"observations_considered": len(observation_ids)},
)
for obs_row in obs_rows:
for sid in obs_row["source_memory_ids"] or []:
superseded_ids.add(str(sid))
Expand All @@ -4725,6 +4750,7 @@ def to_tuple_format(results):
# Chunks are fetched independently of max_tokens filtering
chunks_dict = None
total_chunk_tokens = 0
chunk_fetch_start = time.time()
if include_chunks and top_scored:
from .response_models import ChunkInfo

Expand Down Expand Up @@ -4833,6 +4859,15 @@ def to_tuple_format(results):
)
total_chunk_tokens += chunk_tokens

# Chunk fetch involves up to two SQL round-trips plus per-chunk tiktoken
# encoding; record it only when chunks were actually requested (issue #2361).
if tracer and include_chunks:
tracer.add_phase_metric(
"chunk_fetch",
time.time() - chunk_fetch_start,
{"chunks_returned": len(chunks_dict or {}), "chunk_tokens": total_chunk_tokens},
)

# Step 6: Token budget filtering
step_start = time.time()

Expand All @@ -4856,6 +4891,10 @@ def to_tuple_format(results):
{"results_selected": len(top_scored), "tokens_used": total_tokens, "max_tokens": max_tokens},
)

# Record visits + build the JSON-serializable result dicts. Timed as one
# phase: the visit loop alone walks every scored result (issue #2361).
assembly_start = time.time()

# Record visits for all retrieved nodes
if tracer:
for sr in scored_results:
Expand Down Expand Up @@ -4905,7 +4944,15 @@ def to_tuple_format(results):
)
top_results_dicts.append(result_dict)

if tracer:
tracer.add_phase_metric(
"result_serialization",
time.time() - assembly_start,
{"results_serialized": len(top_results_dicts)},
)

# Fetch source facts for observation-type results (mirrors chunks pattern)
source_fact_start = time.time()
source_fact_ids_by_obs: dict[str, list[str]] = {} # obs_id -> [source_id, ...]
source_facts_dict: dict[str, MemoryFact] | None = None
if include_source_facts:
Expand Down Expand Up @@ -4998,6 +5045,18 @@ def _make_source_fact(sid: str, r: Any) -> MemoryFact:
source_facts_dict[sid] = _make_source_fact(sid, r)
total_source_tokens += fact_tokens

# Source-fact enrichment is two SQL passes + tiktoken encoding; record it
# only when requested (issue #2361).
if tracer and include_source_facts:
tracer.add_phase_metric(
"source_fact_fetch",
time.time() - source_fact_start,
{"source_facts_returned": len(source_facts_dict or {})},
)

# entity fetch + MemoryFact construction + entity-state build, timed together.
entity_build_start = time.time()

# Get entities for each fact if include_entities is requested.
# _entity_rows_for_units_sql resolves both direct unit_entities rows
# and observation-via-source-memory inheritance in a single query.
Expand Down Expand Up @@ -5070,11 +5129,41 @@ def _make_source_fact(sid: str, r: Any) -> MemoryFact:
observations=[], # Mental models provide this now
)

# Finalize trace if enabled
if tracer:
tracer.add_phase_metric(
"entity_build",
time.time() - entity_build_start,
{"entities_returned": len(entities_dict or {})},
)

# Diagnostic phases — these do NOT partition the timeline and are
# excluded from the phase-coverage check (see test_trace_phase_coverage):
# the pool waits overlap other phases (semaphore_wait precedes the
# tracer window; connection_wait is part of parallel_retrieval), and the
# per-method retrieval splits are children of parallel_retrieval.
if semaphore_wait > 0:
tracer.add_phase_metric("semaphore_wait", semaphore_wait, {"diagnostic": True})
if max_conn_wait > 0:
tracer.add_phase_metric("connection_wait", max_conn_wait, {"diagnostic": True})

# Finalize trace if enabled. finalize() snapshots total_duration_seconds at
# entry, so its own object construction + to_dict() serialization fall outside
# that total; we still surface the cost as a diagnostic phase (issue #2361).
trace_dict = None
if tracer:
from .search.trace import SearchPhaseMetrics

finalize_start = time.time()
trace = tracer.finalize(top_results_dicts)
trace_dict = trace.to_dict() if trace else None
if trace_dict is not None:
trace_dict["summary"]["phase_metrics"].append(
SearchPhaseMetrics(
phase_name="trace_finalize",
duration_seconds=time.time() - finalize_start,
details={"diagnostic": True},
).model_dump()
)

# Log final recall stats
total_time = time.time() - recall_start
Expand Down
80 changes: 80 additions & 0 deletions hindsight-api-slim/tests/test_search_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,86 @@ async def test_search_with_trace(memory, request_context):
await memory.delete_bank(bank_id, request_context=request_context)


@pytest.mark.asyncio
async def test_trace_phase_coverage(memory, request_context):
"""Phase metrics should account for (nearly) all of total_duration_seconds.

Regression guard for issue #2361: before the fix the named phases summed to
~10-15% of total because backend acquisition, combined scoring, chunk/source/
entity enrichment and result serialization were un-instrumented.

The per-method ``retrieval_*`` splits and the pool-wait / finalize phases are
flagged ``details.diagnostic`` because they overlap (or sit outside) the
timeline; only the non-diagnostic phases partition it, so only those are summed.
"""
bank_id = f"test_trace_cov_{datetime.now(timezone.utc).timestamp()}"

try:
await memory.retain_async(
bank_id=bank_id,
content="Alice works at Google in Mountain View and joined in 2019.",
context="test context",
request_context=request_context,
)
await memory.retain_async(
bank_id=bank_id,
content="Bob also works at Google but in New York, on the search team.",
context="test context",
request_context=request_context,
)

# Exercise the enrichment paths (chunks + entities) so their phases fire.
search_result = await memory.recall_async(
bank_id=bank_id,
query="Who works at Google?",
fact_type=["world"],
budget=Budget.LOW,
max_tokens=512,
enable_trace=True,
include_chunks=True,
include_entities=True,
request_context=request_context,
)

trace = search_result.trace
assert trace is not None
phase_metrics = trace["summary"]["phase_metrics"]
total = trace["summary"]["total_duration_seconds"]

# The blocks the issue called out must now each have a phase.
phase_names = {pm["phase_name"] for pm in phase_metrics}
for expected in (
"backend_acquisition",
"combined_scoring",
"chunk_fetch",
"result_serialization",
"entity_build",
):
assert expected in phase_names, f"missing phase metric: {expected}"

# Non-diagnostic phases partition the timeline; sum and compare to total.
timeline = [pm for pm in phase_metrics if not pm["details"].get("diagnostic")]
timeline_sum = sum(pm["duration_seconds"] for pm in timeline)

# No phase double-counts: the named partition never exceeds the wall clock.
assert timeline_sum <= total + 0.01, (
f"phase sum {timeline_sum:.4f}s exceeds total {total:.4f}s — likely double counting"
)
# Coverage: only tiny sync gaps between phases should be unaccounted.
# (finalize/to_dict serialization runs after the total snapshot, so it cannot
# be inside this total — it is reported as the diagnostic trace_finalize phase.)
gap = total - timeline_sum
assert gap <= 0.15 * total + 0.02, (
f"unaccounted recall time {gap:.4f}s of {total:.4f}s total — phases sum to "
f"only {timeline_sum / total:.0%}; an un-instrumented block likely regressed"
)

print(f"\n✓ Phase coverage: {timeline_sum / total:.0%} of {total:.3f}s accounted")

finally:
await memory.delete_bank(bank_id, request_context=request_context)


@pytest.mark.asyncio
async def test_search_without_trace(memory, request_context):
"""Test that search with enable_trace=False returns None for trace."""
Expand Down
Loading