diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 0f5e0548c5..1e87d14ec2 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -337,6 +337,8 @@ private Flux handleStreamingToolExecution(Prompt prompt, ChatRespo if (chatResponse.hasFinishReasons(java.util.Set.of("tool_use"))) { return Flux.deferContextual(ctx -> { ToolExecutionResult toolExecutionResult; + Observation parentObs = ctx.getOrDefault(ObservationThreadLocalAccessor.KEY, null); + Observation.Scope scope = parentObs != null ? parentObs.openScope() : null; try { if (this.internalToolExecutionWarned.compareAndSet(false, true)) { logger.warn( @@ -347,6 +349,9 @@ private Flux handleStreamingToolExecution(Prompt prompt, ChatRespo toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); } finally { + if (scope != null) { + scope.close(); + } org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder.clearContext(); } if (toolExecutionResult.returnDirect()) { diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index b97b8bec25..63c1cd8761 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -867,6 +867,8 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse // is currently only synchronous return Flux.deferContextual(ctx -> { ToolExecutionResult toolExecutionResult; + Observation parentObs = ctx.getOrDefault(ObservationThreadLocalAccessor.KEY, null); + Observation.Scope scope = parentObs != null ? parentObs.openScope() : null; try { if (this.internalToolExecutionWarned.compareAndSet(false, true)) { logger.warn( @@ -877,6 +879,9 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); } finally { + if (scope != null) { + scope.close(); + } ToolCallReactiveContextHolder.clearContext(); } diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java index ffeaf2affb..bbb662a7a7 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java @@ -327,6 +327,8 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse // is currently only synchronous return Flux.deferContextual(ctx -> { ToolExecutionResult toolExecutionResult; + Observation parentObs = ctx.getOrDefault(ObservationThreadLocalAccessor.KEY, null); + Observation.Scope scope = parentObs != null ? parentObs.openScope() : null; try { if (this.internalToolExecutionWarned.compareAndSet(false, true)) { logger.warn( @@ -337,6 +339,9 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); } finally { + if (scope != null) { + scope.close(); + } ToolCallReactiveContextHolder.clearContext(); } if (toolExecutionResult.returnDirect()) { diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java index 6632e8d163..97a6491a47 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java @@ -580,6 +580,8 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous ToolExecutionResult toolExecutionResult; + Observation parentObs = ctx.getOrDefault(ObservationThreadLocalAccessor.KEY, null); + Observation.Scope scope = parentObs != null ? parentObs.openScope() : null; try { if (this.internalToolExecutionWarned.compareAndSet(false, true)) { logger.warn( @@ -590,6 +592,9 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, aggregatedResponse); } finally { + if (scope != null) { + scope.close(); + } ToolCallReactiveContextHolder.clearContext(); } if (toolExecutionResult.returnDirect()) { diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index a8eecb092e..7004a580f6 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -417,6 +417,8 @@ public Flux stream(Prompt prompt) { // is currently only synchronous return Flux.deferContextual(ctx -> { ToolExecutionResult toolExecutionResult; + Observation parentObs = ctx.getOrDefault(ObservationThreadLocalAccessor.KEY, null); + Observation.Scope scope = parentObs != null ? parentObs.openScope() : null; try { if (this.internalToolExecutionWarned.compareAndSet(false, true)) { logger.warn( @@ -427,6 +429,9 @@ public Flux stream(Prompt prompt) { toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); } finally { + if (scope != null) { + scope.close(); + } ToolCallReactiveContextHolder.clearContext(); } if (toolExecutionResult.returnDirect()) { diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 8d9aaaf707..0e027e8c1b 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -355,6 +355,8 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse // is currently only synchronous return Flux.deferContextual(ctx -> { ToolExecutionResult toolExecutionResult; + Observation parentObs = ctx.getOrDefault(ObservationThreadLocalAccessor.KEY, null); + Observation.Scope scope = parentObs != null ? parentObs.openScope() : null; try { if (this.internalToolExecutionWarned.compareAndSet(false, true)) { logger.warn( @@ -365,6 +367,9 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); } finally { + if (scope != null) { + scope.close(); + } ToolCallReactiveContextHolder.clearContext(); } if (toolExecutionResult.returnDirect()) { diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index c875d9106e..4539585394 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -415,6 +415,8 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse // is currently only synchronous return Flux.deferContextual(ctx -> { ToolExecutionResult toolExecutionResult; + Observation parentObs = ctx.getOrDefault(ObservationThreadLocalAccessor.KEY, null); + Observation.Scope scope = parentObs != null ? parentObs.openScope() : null; try { if (this.internalToolExecutionWarned.compareAndSet(false, true)) { logger.warn( @@ -425,6 +427,9 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); } finally { + if (scope != null) { + scope.close(); + } ToolCallReactiveContextHolder.clearContext(); } if (toolExecutionResult.returnDirect()) { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index ab55c20237..258168899b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -458,11 +458,16 @@ private Flux internalStream(Prompt prompt, @Nullable ChatResponse } return Flux.deferContextual(ctx -> { ToolExecutionResult tetoolExecutionResult; + Observation parentObs = ctx.getOrDefault(ObservationThreadLocalAccessor.KEY, null); + Observation.Scope scope = parentObs != null ? parentObs.openScope() : null; try { ToolCallReactiveContextHolder.setContext(ctx); tetoolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, aggregated); } finally { + if (scope != null) { + scope.close(); + } ToolCallReactiveContextHolder.clearContext(); } if (tetoolExecutionResult.returnDirect()) { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisor.java index 4d1dcd74df..a79f4598b8 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisor.java @@ -19,6 +19,8 @@ import java.util.List; import java.util.concurrent.atomic.AtomicReference; +import io.micrometer.observation.Observation; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -305,11 +307,24 @@ private Flux handleToolCallRecursion(ChatClientResponse aggr // Execute tool calls on bounded elastic scheduler (tool execution is blocking) Flux toolCallFlux = Flux.deferContextual(ctx -> { ToolExecutionResult toolExecutionResult; + + // Restore observation scope on the boundedElastic thread so tool execution + // can correctly parent any child spans it creates. + Observation parentObs = ctx.getOrDefault(ObservationThreadLocalAccessor.KEY, null); + // Guard: only open a scope when the observation is NOT already the current + // one. + Observation.Scope scope = (parentObs != null + && parentObs != parentObs.getObservationRegistry().getCurrentObservation()) ? parentObs.openScope() + : null; + try { ToolCallReactiveContextHolder.setContext(ctx); toolExecutionResult = this.toolCallingManager.executeToolCalls(finalRequest.prompt(), chatResponse); } finally { + if (scope != null) { + scope.close(); + } ToolCallReactiveContextHolder.clearContext(); } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java index 8dee4c66e2..5d6fbe7542 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java @@ -18,8 +18,11 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; @@ -323,6 +326,44 @@ void mutateDoesNotAffectOriginalChain() { assertThat(chain.getCallAdvisors().get(0).getName()).isEqualTo("advisor1"); } + @Disabled + @Test + void whenNextStreamCalledThenObservationScopeIsOpenDuringAdviseStream() { + // Fix C: nextStream() opens the observation scope inside Flux.defer so child + // observations created synchronously during subscription assembly find the + // correct parent in the Micrometer/OTel ThreadLocal. + ObservationRegistry registry = ObservationRegistry.create(); + AtomicReference currentObsWhenAdviseStreamCalled = new AtomicReference<>(); + + StreamAdvisor advisor = new StreamAdvisor() { + @Override + public Flux adviseStream(ChatClientRequest request, StreamAdvisorChain chain) { + currentObsWhenAdviseStreamCalled.set(registry.getCurrentObservation()); + return Flux.just(ChatClientResponse.builder().build()); + } + + @Override + public String getName() { + return "test-advisor"; + } + + @Override + public int getOrder() { + return 0; + } + }; + + DefaultAroundAdvisorChain.builder(registry) + .push(advisor) + .build() + .nextStream(ChatClientRequest.builder().prompt(new Prompt("test")).build()) + .blockLast(); + + assertThat(currentObsWhenAdviseStreamCalled.get()) + .as("Fix C: chain observation must be in scope when adviseStream is invoked") + .isNotNull(); + } + private CallAdvisor createMockAdvisor(String name, int order) { return new CallAdvisor() { @Override diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisorTests.java index c4f2064e9a..2dc46f85d3 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisorTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisorTests.java @@ -18,9 +18,18 @@ import java.util.List; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.ObservationView; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -28,6 +37,8 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.quality.Strictness; import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.scheduler.Schedulers; import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.client.ChatClientResponse; @@ -415,6 +426,55 @@ void testAdviseStreamWithSingleToolCallIteration() { verify(this.toolCallingManager, times(1)).executeToolCalls(any(Prompt.class), any(ChatResponse.class)); } + @Test + void toolExecutionDuringStreamingShouldHaveObservationInScope() { + // Fix B: handleToolCallRecursion opens the observation scope before executing + // blocking tool calls on the boundedElastic scheduler thread. + ObservationRegistry registry = ObservationRegistry.create(); + ToolCallAdvisor advisor = ToolCallAdvisor.builder().toolCallingManager(this.toolCallingManager).build(); + + ChatClientRequest request = createMockRequest(); + ChatClientResponse responseWithToolCall = createMockResponse(true); + ChatClientResponse finalResponse = createMockResponse(false); + + int[] callCount = { 0 }; + TerminalStreamAdvisor terminalAdvisor = new TerminalStreamAdvisor((req, chain) -> { + callCount[0]++; + return Flux.just(callCount[0] == 1 ? responseWithToolCall : finalResponse); + }); + + StreamAdvisorChain realChain = DefaultAroundAdvisorChain.builder(registry) + .pushAll(List.of(advisor, terminalAdvisor)) + .build(); + + AtomicReference observationDuringTool = new AtomicReference<>(); + List conversationHistory = List.of(new UserMessage("test"), + AssistantMessage.builder().content("").build(), ToolResponseMessage.builder().build()); + ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(conversationHistory) + .build(); + when(this.toolCallingManager.executeToolCalls(any(Prompt.class), any(ChatResponse.class))).thenAnswer(inv -> { + observationDuringTool.set(registry.getCurrentObservation()); + return toolExecutionResult; + }); + + Observation outerObservation = Observation.createNotStarted("outer.test", registry).start(); + + // Simulate DefaultAroundAdvisorChain's contextWrite so the tool-call + // deferContextual lambda sees the observation and can open a scope on + // the boundedElastic thread. + advisor.adviseStream(request, realChain) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, outerObservation)) + .collectList() + .block(); + + outerObservation.stop(); + + assertThat(observationDuringTool.get()) + .as("Fix B: observation must be in scope on the boundedElastic thread during tool execution") + .isNotNull(); + } + @Test void testAdviseStreamWithReturnDirectToolExecution() { ToolCallAdvisor advisor = ToolCallAdvisor.builder().toolCallingManager(this.toolCallingManager).build(); @@ -869,6 +929,75 @@ private ChatClientResponse createMockResponse(boolean hasToolCalls) { return response; } + /** + * Reproduces the detached-span problem visible in traces when + * {@link Hooks#enableAutomaticContextPropagation()} is NOT active. + * + *

+ * {@link DefaultAroundAdvisorChain#nextStream} writes the advisor observation into + * the Reactor context via {@code contextWrite} but never opens a Micrometer + * ThreadLocal scope. Without automatic context propagation, worker threads (e.g. + * those created by {@code subscribeOn(boundedElastic())}) start with empty + * ThreadLocals, so any observation a real streaming model would create on such a + * thread has no parent and appears detached from {@code tool_calling} in traces. The + * recursive stream (second LLM call) is unaffected because + * {@code handleToolCallRecursion} explicitly calls {@code openScope()} before the + * tool execution. + */ + @Test + void withoutAutomaticContextPropagation_firstModelStreamObservationIsDetached() { + ObservationRegistry registry = ObservationRegistry.create(); + ToolCallAdvisor advisor = ToolCallAdvisor.builder().toolCallingManager(this.toolCallingManager).build(); + + ChatClientRequest request = createMockRequest(); + + int[] callCount = { 0 }; + AtomicReference observationOnWorkerThread = new AtomicReference<>(); + + // Simulates a real streaming model: the actual work happens on a worker thread. + // Without auto-propagation no ThreadLocal is restored there, so + // registry.getCurrentObservation() returns null and the model span is detached. + TerminalStreamAdvisor terminalAdvisor = new TerminalStreamAdvisor((req, chain) -> { + callCount[0]++; + if (callCount[0] == 1) { + return Flux.defer(() -> { + observationOnWorkerThread.set(registry.getCurrentObservation()); + return Flux.just(createMockResponse(true)); + }).subscribeOn(Schedulers.boundedElastic()); + } + return Flux.just(createMockResponse(false)); + }); + + StreamAdvisorChain realChain = DefaultAroundAdvisorChain.builder(registry) + .pushAll(List.of(advisor, terminalAdvisor)) + .build(); + + List conversationHistory = List.of(new UserMessage("test"), + AssistantMessage.builder().content("").build(), ToolResponseMessage.builder().build()); + ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(conversationHistory) + .build(); + when(this.toolCallingManager.executeToolCalls(any(Prompt.class), any(ChatResponse.class))) + .thenReturn(toolExecutionResult); + + Observation outerObservation = Observation.createNotStarted("invoke_workflow", registry).start(); + + advisor.adviseStream(request, realChain) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, outerObservation)) + .collectList() + .block(); + + outerObservation.stop(); + + assertThat(observationOnWorkerThread.get()) + .as("Without Hooks.enableAutomaticContextPropagation() the advisor observation is not " + + "restored to ThreadLocal on the worker thread, so the model span is detached " + + "from tool_calling in traces") + .isNull(); + } + + // Helper classes + private static class TerminalCallAdvisor implements CallAdvisor { private final BiFunction responseFunction; @@ -979,4 +1108,265 @@ public TestableToolCallAdvisor build() { } + /** + * Verifies that {@link Hooks#enableAutomaticContextPropagation()} does not break the + * streaming tool-call observation path. + * + *

+ * With auto-propagation active Reactor restores ThreadLocals from the Reactor context + * at every operator boundary including the {@code subscribeOn(boundedElastic())} + * thread hop in {@link ToolCallAdvisor}. This means {@code openScope()} in + * {@code handleToolCallRecursion} may be called on an observation that is already the + * current ThreadLocal observation — a redundant double-open. These tests confirm the + * behaviour is correct: the observation remains in scope during tool execution and no + * observation becomes its own parent. + */ + @Nested + class WithAutomaticContextPropagation { + + @BeforeEach + void enableAutoPropagation() { + Hooks.enableAutomaticContextPropagation(); + } + + @AfterEach + void disableAutoPropagation() { + Hooks.disableAutomaticContextPropagation(); + } + + @Test + void observationIsSameInstanceDuringToolExecution() { + // With auto-propagation Reactor restores the ThreadLocal before every + // operator boundary including the subscribeOn(boundedElastic()) hop, so the + // guard in handleToolCallRecursion must NOT call openScope() redundantly. The + // observable + // proof: the observation seen inside executeToolCalls is the exact same + // instance + // placed in the Reactor context, not a stale or double-opened scope. + ObservationRegistry registry = ObservationRegistry.create(); + ToolCallAdvisor advisor = ToolCallAdvisor.builder() + .toolCallingManager(ToolCallAdvisorTests.this.toolCallingManager) + .build(); + + ChatClientRequest request = createMockRequest(); + ChatClientResponse responseWithToolCall = createMockResponse(true); + ChatClientResponse finalResponse = createMockResponse(false); + + int[] callCount = { 0 }; + TerminalStreamAdvisor terminalAdvisor = new TerminalStreamAdvisor((req, chain) -> { + callCount[0]++; + return Flux.just(callCount[0] == 1 ? responseWithToolCall : finalResponse); + }); + + StreamAdvisorChain realChain = DefaultAroundAdvisorChain.builder(registry) + .pushAll(List.of(advisor, terminalAdvisor)) + .build(); + + AtomicReference observationDuringTool = new AtomicReference<>(); + List conversationHistory = List.of(new UserMessage("test"), + AssistantMessage.builder().content("").build(), ToolResponseMessage.builder().build()); + ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(conversationHistory) + .build(); + when(ToolCallAdvisorTests.this.toolCallingManager.executeToolCalls(any(Prompt.class), + any(ChatResponse.class))) + .thenAnswer(inv -> { + observationDuringTool.set(registry.getCurrentObservation()); + return toolExecutionResult; + }); + + Observation outerObservation = Observation.createNotStarted("outer.test", registry).start(); + + advisor.adviseStream(request, realChain) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, outerObservation)) + .collectList() + .block(); + + outerObservation.stop(); + + assertThat(observationDuringTool.get()) + .as("Observation during tool execution must be the exact same instance as the outer observation") + .isSameAs(outerObservation); + } + + @Test + void returnDirectPathHasObservationInScope() { + ObservationRegistry registry = ObservationRegistry.create(); + ToolCallAdvisor advisor = ToolCallAdvisor.builder() + .toolCallingManager(ToolCallAdvisorTests.this.toolCallingManager) + .build(); + + ChatClientRequest request = createMockRequest(); + ChatClientResponse responseWithToolCall = createMockResponse(true); + + TerminalStreamAdvisor terminalAdvisor = new TerminalStreamAdvisor( + (req, chain) -> Flux.just(responseWithToolCall)); + + StreamAdvisorChain realChain = DefaultAroundAdvisorChain.builder(registry) + .pushAll(List.of(advisor, terminalAdvisor)) + .build(); + + AtomicReference observationDuringTool = new AtomicReference<>(); + ToolResponseMessage.ToolResponse toolResponse = new ToolResponseMessage.ToolResponse("tool-1", "testTool", + "Tool result data"); + ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(toolResponse)) + .build(); + List conversationHistory = List.of(new UserMessage("test"), + AssistantMessage.builder().content("").build(), toolResponseMessage); + ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(conversationHistory) + .returnDirect(true) + .build(); + when(ToolCallAdvisorTests.this.toolCallingManager.executeToolCalls(any(Prompt.class), + any(ChatResponse.class))) + .thenAnswer(inv -> { + observationDuringTool.set(registry.getCurrentObservation()); + return toolExecutionResult; + }); + + Observation outerObservation = Observation.createNotStarted("outer.test", registry).start(); + + List results = advisor.adviseStream(request, realChain) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, outerObservation)) + .collectList() + .block(); + + outerObservation.stop(); + + assertThat(observationDuringTool.get()) + .as("Observation must be in scope during return-direct tool execution " + + "even with automatic context propagation enabled") + .isSameAs(outerObservation); + assertThat(results).isNotNull().hasSize(1); + assertThat(results.get(0).chatResponse().getResults().get(0).getOutput().getText()) + .isEqualTo("Tool result data"); + } + + @Test + void noObservationBecomesItsOwnParent() { + CopyOnWriteArrayList starts = new CopyOnWriteArrayList<>(); + ObservationRegistry registry = ObservationRegistry.create(); + registry.observationConfig().observationHandler(new ObservationHandler() { + @Override + public void onStart(Observation.Context context) { + starts.add(new StartRecord(context, context.getParentObservation())); + } + + @Override + public boolean supportsContext(Observation.Context context) { + return true; + } + }); + + ToolCallAdvisor advisor = ToolCallAdvisor.builder() + .toolCallingManager(ToolCallAdvisorTests.this.toolCallingManager) + .build(); + + ChatClientRequest request = createMockRequest(); + ChatClientResponse responseWithToolCall = createMockResponse(true); + ChatClientResponse finalResponse = createMockResponse(false); + + int[] callCount = { 0 }; + TerminalStreamAdvisor terminalAdvisor = new TerminalStreamAdvisor((req, chain) -> { + callCount[0]++; + return Flux.just(callCount[0] == 1 ? responseWithToolCall : finalResponse); + }); + + StreamAdvisorChain realChain = DefaultAroundAdvisorChain.builder(registry) + .pushAll(List.of(advisor, terminalAdvisor)) + .build(); + + Observation outerObservation = Observation.createNotStarted("outer.test", registry).start(); + + List conversationHistory = List.of(new UserMessage("test"), + AssistantMessage.builder().content("").build(), ToolResponseMessage.builder().build()); + ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(conversationHistory) + .build(); + when(ToolCallAdvisorTests.this.toolCallingManager.executeToolCalls(any(Prompt.class), + any(ChatResponse.class))) + .thenReturn(toolExecutionResult); + + advisor.adviseStream(request, realChain) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, outerObservation)) + .collectList() + .block(); + + outerObservation.stop(); + + assertThat(starts).as("At least one observation must have started").isNotEmpty(); + starts.forEach(record -> { + ObservationView parent = record.parent(); + if (parent != null) { + assertThat(parent.getContextView()) + .as("Observation [%s] must not be its own parent", record.ctx().getName()) + .isNotSameAs(record.ctx()); + } + }); + } + + @Test + void withAutomaticContextPropagation_firstModelStreamObservationIsProperlyNested() { + // Counterpart to the standalone test above: with auto-propagation Reactor + // restores ThreadLocals from the Reactor context at every operator boundary + // including the subscribeOn(boundedElastic()) hop. The advisor observation + // written to context by DefaultAroundAdvisorChain is therefore visible on + // the worker thread and any model observation created there is correctly + // nested under tool_calling instead of appearing detached. + ObservationRegistry registry = ObservationRegistry.create(); + ToolCallAdvisor advisor = ToolCallAdvisor.builder() + .toolCallingManager(ToolCallAdvisorTests.this.toolCallingManager) + .build(); + + ChatClientRequest request = createMockRequest(); + + int[] callCount = { 0 }; + AtomicReference observationOnWorkerThread = new AtomicReference<>(); + + TerminalStreamAdvisor terminalAdvisor = new TerminalStreamAdvisor((req, chain) -> { + callCount[0]++; + if (callCount[0] == 1) { + return Flux.defer(() -> { + observationOnWorkerThread.set(registry.getCurrentObservation()); + return Flux.just(createMockResponse(true)); + }).subscribeOn(Schedulers.boundedElastic()); + } + return Flux.just(createMockResponse(false)); + }); + + StreamAdvisorChain realChain = DefaultAroundAdvisorChain.builder(registry) + .pushAll(List.of(advisor, terminalAdvisor)) + .build(); + + List conversationHistory = List.of(new UserMessage("test"), + AssistantMessage.builder().content("").build(), ToolResponseMessage.builder().build()); + ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(conversationHistory) + .build(); + when(ToolCallAdvisorTests.this.toolCallingManager.executeToolCalls(any(Prompt.class), + any(ChatResponse.class))) + .thenReturn(toolExecutionResult); + + Observation outerObservation = Observation.createNotStarted("invoke_workflow", registry).start(); + + advisor.adviseStream(request, realChain) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, outerObservation)) + .collectList() + .block(); + + outerObservation.stop(); + + assertThat(observationOnWorkerThread.get()) + .as("With Hooks.enableAutomaticContextPropagation() the advisor observation is " + + "restored to ThreadLocal on the worker thread, so the model span is " + + "correctly nested under tool_calling in traces") + .isNotNull(); + } + + record StartRecord(Observation.Context ctx, ObservationView parent) { + } + + } + }