diff --git a/shared/ska_utils/src/ska_utils/__init__.py b/shared/ska_utils/src/ska_utils/__init__.py index 715912a6a..a8bbdac84 100644 --- a/shared/ska_utils/src/ska_utils/__init__.py +++ b/shared/ska_utils/src/ska_utils/__init__.py @@ -18,6 +18,7 @@ from .telemetry import ( TA_OTEL_ENDPOINT as TA_OTEL_ENDPOINT, TA_TELEMETRY_ENABLED as TA_TELEMETRY_ENABLED, + AgentTelemetryLogger as AgentTelemetryLogger, Telemetry as Telemetry, get_telemetry as get_telemetry, initialize_telemetry as initialize_telemetry, diff --git a/shared/ska_utils/src/ska_utils/telemetry.py b/shared/ska_utils/src/ska_utils/telemetry.py index 955a893e7..4b4cd497a 100644 --- a/shared/ska_utils/src/ska_utils/telemetry.py +++ b/shared/ska_utils/src/ska_utils/telemetry.py @@ -1,4 +1,7 @@ +import json import logging +from contextlib import contextmanager +from typing import Any from opentelemetry import trace from opentelemetry._logs import set_logger_provider @@ -173,6 +176,226 @@ def _enable_metrics(self) -> None: set_meter_provider(meter_provider) +class AgentTelemetryLogger: + """Provides standardized structured telemetry logging and span enrichment + for agent invocations. + + Captures metadata including: agent name, model used, tool calls, + tool call count, reasoning/thinking, user ISID, internal function calls, + and token usage. + + Log output follows a standardized JSON-like format: + { + "agent.name": "weather agent", + "agent.model": "gpt-4o", + "agent.tool_calls": ["get_weather", "get_location"], + "agent.tool_call_count": 2, + "agent.reasoning": "...", + "agent.user_isid": "user123", + ... + } + """ + + def __init__( + self, + agent_name: str, + model_name: str, + user_isid: str | None = None, + telemetry: "Telemetry | None" = None, + ): + self._telemetry = telemetry + self._logger = logging.getLogger(f"agent_telemetry.{agent_name}") + self._agent_name = agent_name + self._model_name = model_name + self._user_isid = user_isid + self._tool_calls: list[str] = [] + self._internal_function_calls: list[str] = [] + self._reasoning_entries: list[str] = [] + self._invocation_count: int = 0 + + @property + def agent_name(self) -> str: + return self._agent_name + + @property + def model_name(self) -> str: + return self._model_name + + @property + def user_isid(self) -> str | None: + return self._user_isid + + @property + def tool_calls(self) -> list[str]: + return list(self._tool_calls) + + @property + def tool_call_count(self) -> int: + return len(self._tool_calls) + + @property + def internal_function_calls(self) -> list[str]: + return list(self._internal_function_calls) + + @property + def reasoning_entries(self) -> list[str]: + return list(self._reasoning_entries) + + @property + def invocation_count(self) -> int: + return self._invocation_count + + def record_tool_call(self, tool_name: str) -> None: + """Record a tool/plugin call made by the agent.""" + self._tool_calls.append(tool_name) + + def record_tool_calls(self, tool_names: list[str]) -> None: + """Record multiple tool/plugin calls made by the agent.""" + self._tool_calls.extend(tool_names) + + def record_internal_function_call(self, function_name: str) -> None: + """Record an internal function call (kernel function invocation).""" + self._internal_function_calls.append(function_name) + + def record_reasoning(self, reasoning: str) -> None: + """Record a reasoning/thinking step performed by the agent.""" + if reasoning: + self._reasoning_entries.append(reasoning) + + def record_invocation(self) -> None: + """Increment the agent invocation counter.""" + self._invocation_count += 1 + + def get_standardized_log( + self, + session_id: str | None = None, + request_id: str | None = None, + completion_tokens: int = 0, + prompt_tokens: int = 0, + total_tokens: int = 0, + ) -> dict[str, Any]: + """Return the standardized metadata dict for structured logging. + + Returns a dict with keys following the ``agent.*`` namespace: + agent.name, agent.model, agent.tool_calls, agent.tool_call_count, + agent.internal_function_calls, agent.internal_function_call_count, + agent.reasoning, agent.user_isid, agent.invocation_count, + agent.session_id, agent.request_id, agent.completion_tokens, + agent.prompt_tokens, agent.total_tokens. + """ + log_data: dict[str, Any] = { + "agent.name": self._agent_name, + "agent.model": self._model_name, + "agent.tool_calls": list(self._tool_calls), + "agent.tool_call_count": self.tool_call_count, + "agent.internal_function_calls": list(self._internal_function_calls), + "agent.internal_function_call_count": len(self._internal_function_calls), + "agent.reasoning": list(self._reasoning_entries), + "agent.user_isid": self._user_isid or "", + "agent.invocation_count": self._invocation_count, + "agent.session_id": session_id or "", + "agent.request_id": request_id or "", + "agent.completion_tokens": completion_tokens, + "agent.prompt_tokens": prompt_tokens, + "agent.total_tokens": total_tokens, + } + return log_data + + def emit_log( + self, + session_id: str | None = None, + request_id: str | None = None, + completion_tokens: int = 0, + prompt_tokens: int = 0, + total_tokens: int = 0, + ) -> dict[str, Any]: + """Emit a structured log message with all collected agent metadata. + + Returns the log data dict for convenience. + """ + log_data = self.get_standardized_log( + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + ) + self._logger.info("agent_invocation_summary: %s", json.dumps(log_data)) + return log_data + + def enrich_span( + self, + span: trace.Span | None, + session_id: str | None = None, + request_id: str | None = None, + completion_tokens: int = 0, + prompt_tokens: int = 0, + total_tokens: int = 0, + time_to_first_token_ms: float | None = None, + ) -> None: + """Enrich an OpenTelemetry span with all collected agent metadata.""" + if span is None: + return + + span.set_attribute("agent.name", self._agent_name) + span.set_attribute("agent.model", self._model_name) + span.set_attribute("agent.tool_calls", list(self._tool_calls)) + span.set_attribute("agent.tool_call_count", self.tool_call_count) + span.set_attribute( + "agent.internal_function_calls", list(self._internal_function_calls) + ) + span.set_attribute( + "agent.internal_function_call_count", len(self._internal_function_calls) + ) + span.set_attribute("agent.invocation_count", self._invocation_count) + span.set_attribute("agent.user_isid", self._user_isid or "") + span.set_attribute("agent.session_id", session_id or "") + span.set_attribute("agent.request_id", request_id or "") + span.set_attribute("agent.completion_tokens", completion_tokens) + span.set_attribute("agent.prompt_tokens", prompt_tokens) + span.set_attribute("agent.total_tokens", total_tokens) + + if self._reasoning_entries: + span.set_attribute("agent.reasoning", list(self._reasoning_entries)) + + if time_to_first_token_ms is not None: + span.add_event( + "agent_time_to_first_token", + attributes={"first_token_time_ms": time_to_first_token_ms}, + ) + + @contextmanager + def trace_agent_invocation( + self, + span_name: str, + session_id: str | None = None, + request_id: str | None = None, + ): + """Context manager that creates a span for agent invocation and + automatically enriches it on exit with all collected metadata. + + Usage:: + + with agent_logger.trace_agent_invocation("handler-invoke") as span: + # ... perform agent work, record tool calls, etc. + pass + # span is automatically enriched on exit + + Yields the span (or ``None`` if telemetry is disabled). + """ + self.record_invocation() + + if ( + self._telemetry is not None + and self._telemetry.telemetry_enabled() + and self._telemetry.tracer + ): + with self._telemetry.tracer.start_as_current_span(span_name) as span: + yield span + else: + yield None + + _services_telemetry: Telemetry | None = None diff --git a/shared/ska_utils/tests/test_telemetry.py b/shared/ska_utils/tests/test_telemetry.py index c266570fd..7cfa8930a 100644 --- a/shared/ska_utils/tests/test_telemetry.py +++ b/shared/ska_utils/tests/test_telemetry.py @@ -1,3 +1,4 @@ +import json import logging from unittest.mock import MagicMock, patch @@ -5,6 +6,7 @@ from opentelemetry.trace import Tracer from ska_utils import AppConfig, Telemetry, get_telemetry, initialize_telemetry +from ska_utils.telemetry import AgentTelemetryLogger @pytest.fixture @@ -190,3 +192,271 @@ def test_initialize_telemetry(app_config): initialize_telemetry("test_service", app_config) telemetry = get_telemetry() assert telemetry.service_name == "test_service" + + +# =================================================================== +# Tests for AgentTelemetryLogger +# =================================================================== + + +@pytest.fixture +def telemetry_disabled(app_config): + """Create a telemetry instance with telemetry disabled.""" + app_config.get.side_effect = { + "TA_TELEMETRY_ENABLED": "false", + "TA_METRICS_ENABLED": "false", + "TA_LOGGING_ENABLED": "false", + "TA_OTEL_ENDPOINT": None, + "TA_LOG_LEVEL": "info", + }.get + return Telemetry("test_service", app_config) + + +@pytest.fixture +def telemetry_enabled(app_config): + """Create a telemetry instance with telemetry enabled.""" + return Telemetry("test_service", app_config) + + +@pytest.fixture +def agent_logger(telemetry_disabled): + """Create an AgentTelemetryLogger for testing.""" + return AgentTelemetryLogger( + telemetry=telemetry_disabled, + agent_name="weather_agent", + model_name="gpt-4o", + user_isid="user123", + ) + + +class TestAgentTelemetryLoggerInit: + def test_initial_state(self, agent_logger): + assert agent_logger.agent_name == "weather_agent" + assert agent_logger.model_name == "gpt-4o" + assert agent_logger.user_isid == "user123" + assert agent_logger.tool_calls == [] + assert agent_logger.tool_call_count == 0 + assert agent_logger.internal_function_calls == [] + assert agent_logger.reasoning_entries == [] + assert agent_logger.invocation_count == 0 + + def test_init_without_user_isid(self, telemetry_disabled): + logger = AgentTelemetryLogger( + telemetry=telemetry_disabled, + agent_name="test_agent", + model_name="gpt-4o", + ) + assert logger.user_isid is None + + +class TestAgentTelemetryLoggerRecording: + def test_record_tool_call(self, agent_logger): + agent_logger.record_tool_call("get_weather") + assert agent_logger.tool_calls == ["get_weather"] + assert agent_logger.tool_call_count == 1 + + def test_record_multiple_tool_calls(self, agent_logger): + agent_logger.record_tool_call("get_weather") + agent_logger.record_tool_call("get_location") + assert agent_logger.tool_calls == ["get_weather", "get_location"] + assert agent_logger.tool_call_count == 2 + + def test_record_tool_calls_batch(self, agent_logger): + agent_logger.record_tool_calls(["get_weather", "get_location", "get_time"]) + assert agent_logger.tool_calls == ["get_weather", "get_location", "get_time"] + assert agent_logger.tool_call_count == 3 + + def test_record_internal_function_call(self, agent_logger): + agent_logger.record_internal_function_call("WeatherPlugin.get_weather") + assert agent_logger.internal_function_calls == ["WeatherPlugin.get_weather"] + + def test_record_reasoning(self, agent_logger): + agent_logger.record_reasoning("I need to check the weather first") + assert agent_logger.reasoning_entries == ["I need to check the weather first"] + + def test_record_reasoning_empty_string_ignored(self, agent_logger): + agent_logger.record_reasoning("") + assert agent_logger.reasoning_entries == [] + + def test_record_invocation(self, agent_logger): + agent_logger.record_invocation() + assert agent_logger.invocation_count == 1 + agent_logger.record_invocation() + assert agent_logger.invocation_count == 2 + + def test_tool_calls_returns_copy(self, agent_logger): + agent_logger.record_tool_call("get_weather") + calls = agent_logger.tool_calls + calls.append("mutated") + assert agent_logger.tool_calls == ["get_weather"] + + +class TestAgentTelemetryLoggerStandardizedLog: + def test_get_standardized_log_basic(self, agent_logger): + log_data = agent_logger.get_standardized_log() + assert log_data["agent.name"] == "weather_agent" + assert log_data["agent.model"] == "gpt-4o" + assert log_data["agent.tool_calls"] == [] + assert log_data["agent.tool_call_count"] == 0 + assert log_data["agent.user_isid"] == "user123" + assert log_data["agent.invocation_count"] == 0 + assert log_data["agent.session_id"] == "" + assert log_data["agent.request_id"] == "" + assert log_data["agent.completion_tokens"] == 0 + assert log_data["agent.prompt_tokens"] == 0 + assert log_data["agent.total_tokens"] == 0 + + def test_get_standardized_log_with_tool_calls(self, agent_logger): + agent_logger.record_tool_calls(["get_weather", "get_location"]) + agent_logger.record_internal_function_call("WeatherPlugin.get_weather") + agent_logger.record_reasoning("reasoning_tokens=50") + agent_logger.record_invocation() + + log_data = agent_logger.get_standardized_log( + session_id="sess-123", + request_id="req-456", + completion_tokens=100, + prompt_tokens=200, + total_tokens=300, + ) + assert log_data["agent.name"] == "weather_agent" + assert log_data["agent.model"] == "gpt-4o" + assert log_data["agent.tool_calls"] == ["get_weather", "get_location"] + assert log_data["agent.tool_call_count"] == 2 + assert log_data["agent.internal_function_calls"] == ["WeatherPlugin.get_weather"] + assert log_data["agent.internal_function_call_count"] == 1 + assert log_data["agent.reasoning"] == ["reasoning_tokens=50"] + assert log_data["agent.user_isid"] == "user123" + assert log_data["agent.invocation_count"] == 1 + assert log_data["agent.session_id"] == "sess-123" + assert log_data["agent.request_id"] == "req-456" + assert log_data["agent.completion_tokens"] == 100 + assert log_data["agent.prompt_tokens"] == 200 + assert log_data["agent.total_tokens"] == 300 + + def test_get_standardized_log_no_user_isid(self, telemetry_disabled): + logger = AgentTelemetryLogger( + telemetry=telemetry_disabled, + agent_name="test_agent", + model_name="gpt-4o", + ) + log_data = logger.get_standardized_log() + assert log_data["agent.user_isid"] == "" + + def test_emit_log(self, agent_logger): + agent_logger.record_tool_call("get_weather") + with patch.object(agent_logger._logger, "info") as mock_info: + log_data = agent_logger.emit_log( + session_id="sess-1", + request_id="req-1", + completion_tokens=10, + prompt_tokens=20, + total_tokens=30, + ) + mock_info.assert_called_once() + call_args = mock_info.call_args + # Verify the log message contains JSON + log_message_json = call_args[0][1] + parsed = json.loads(log_message_json) + assert parsed["agent.name"] == "weather_agent" + assert parsed["agent.tool_calls"] == ["get_weather"] + assert parsed["agent.tool_call_count"] == 1 + + # Also verify the returned dict + assert log_data["agent.name"] == "weather_agent" + assert log_data["agent.total_tokens"] == 30 + + +class TestAgentTelemetryLoggerSpanEnrichment: + def test_enrich_span(self, agent_logger): + mock_span = MagicMock() + agent_logger.record_tool_calls(["get_weather", "get_location"]) + agent_logger.record_internal_function_call("WeatherPlugin.get_weather") + agent_logger.record_reasoning("reasoning_tokens=50") + agent_logger.record_invocation() + + agent_logger.enrich_span( + span=mock_span, + session_id="sess-1", + request_id="req-1", + completion_tokens=100, + prompt_tokens=200, + total_tokens=300, + time_to_first_token_ms=42.5, + ) + + mock_span.set_attribute.assert_any_call("agent.name", "weather_agent") + mock_span.set_attribute.assert_any_call("agent.model", "gpt-4o") + mock_span.set_attribute.assert_any_call( + "agent.tool_calls", ["get_weather", "get_location"] + ) + mock_span.set_attribute.assert_any_call("agent.tool_call_count", 2) + mock_span.set_attribute.assert_any_call( + "agent.internal_function_calls", ["WeatherPlugin.get_weather"] + ) + mock_span.set_attribute.assert_any_call("agent.internal_function_call_count", 1) + mock_span.set_attribute.assert_any_call("agent.invocation_count", 1) + mock_span.set_attribute.assert_any_call("agent.user_isid", "user123") + mock_span.set_attribute.assert_any_call("agent.session_id", "sess-1") + mock_span.set_attribute.assert_any_call("agent.request_id", "req-1") + mock_span.set_attribute.assert_any_call("agent.completion_tokens", 100) + mock_span.set_attribute.assert_any_call("agent.prompt_tokens", 200) + mock_span.set_attribute.assert_any_call("agent.total_tokens", 300) + mock_span.set_attribute.assert_any_call("agent.reasoning", ["reasoning_tokens=50"]) + mock_span.add_event.assert_called_once_with( + "agent_time_to_first_token", + attributes={"first_token_time_ms": 42.5}, + ) + + def test_enrich_span_none(self, agent_logger): + # Should not raise when span is None + agent_logger.enrich_span(span=None) + + def test_enrich_span_no_reasoning(self, agent_logger): + mock_span = MagicMock() + agent_logger.enrich_span(span=mock_span) + # Reasoning should not be set if empty + reasoning_calls = [ + call for call in mock_span.set_attribute.call_args_list + if call[0][0] == "agent.reasoning" + ] + assert len(reasoning_calls) == 0 + + def test_enrich_span_no_ttft(self, agent_logger): + mock_span = MagicMock() + agent_logger.enrich_span(span=mock_span) + mock_span.add_event.assert_not_called() + + +class TestAgentTelemetryLoggerTraceContext: + def test_trace_agent_invocation_disabled(self, telemetry_disabled): + logger = AgentTelemetryLogger( + telemetry=telemetry_disabled, + agent_name="test_agent", + model_name="gpt-4o", + ) + with logger.trace_agent_invocation("test-span") as span: + assert span is None + assert logger.invocation_count == 1 + + def test_trace_agent_invocation_enabled(self, telemetry_enabled): + logger = AgentTelemetryLogger( + telemetry=telemetry_enabled, + agent_name="test_agent", + model_name="gpt-4o", + ) + with logger.trace_agent_invocation("test-span", session_id="s1") as span: + assert span is not None + assert logger.invocation_count == 1 + + def test_trace_agent_invocation_increments_count(self, telemetry_disabled): + logger = AgentTelemetryLogger( + telemetry=telemetry_disabled, + agent_name="test_agent", + model_name="gpt-4o", + ) + with logger.trace_agent_invocation("span-1"): + pass + with logger.trace_agent_invocation("span-2"): + pass + assert logger.invocation_count == 2 diff --git a/src/orchestrators/assistant-orchestrator/orchestrator/agents.py b/src/orchestrators/assistant-orchestrator/orchestrator/agents.py index b00ff3c31..ea2143db3 100644 --- a/src/orchestrators/assistant-orchestrator/orchestrator/agents.py +++ b/src/orchestrators/assistant-orchestrator/orchestrator/agents.py @@ -3,6 +3,7 @@ from collections.abc import AsyncIterable import requests +import requests.exceptions import websockets from opentelemetry.propagate import inject from pydantic import BaseModel, ConfigDict @@ -13,6 +14,43 @@ logger = logging.getLogger(__name__) +class AgentConnectionError(Exception): + """Raised when an agent cannot be reached (down, DNS failure, refused).""" + def __init__(self, agent_name: str, message: str = ""): + self.agent_name = agent_name + self.message = message or f"Agent '{agent_name}' is not available or cannot be reached." + super().__init__(self.message) + + +class AgentTimeoutError(Exception): + """Raised when an agent does not respond within the timeout period.""" + def __init__(self, agent_name: str, message: str = ""): + self.agent_name = agent_name + self.message = message or f"Agent '{agent_name}' timed out while processing the request." + super().__init__(self.message) + + +class AgentResponseError(Exception): + """Raised when an agent returns a non-200 response.""" + def __init__(self, agent_name: str, status_code: int, detail: str = ""): + self.agent_name = agent_name + self.status_code = status_code + self.detail = detail + self.message = f"Agent '{agent_name}' returned an error (HTTP {status_code}): {detail}" + super().__init__(self.message) + + +class AgentInvalidResponseError(Exception): + """Raised when an agent returns a response that cannot be parsed.""" + def __init__(self, agent_name: str, message: str = ""): + self.agent_name = agent_name + self.message = ( + message + or f"Agent '{agent_name}' returned an invalid or unparseable response." + ) + super().__init__(self.message) + + class MultiModalItem(BaseModel): content_type: str content: str @@ -101,10 +139,35 @@ async def invoke_stream( "Authorization": authorization, } inject(headers) - async with websockets.connect(self.endpoint, additional_headers=headers) as ws: - await ws.send(input_message) - async for message in ws: - yield message + try: + async with websockets.connect(self.endpoint, additional_headers=headers) as ws: + await ws.send(input_message) + async for message in ws: + yield message + except (OSError, ConnectionRefusedError, websockets.exceptions.InvalidURI) as e: + logger.error( + f"Agent '{self.name}' is unreachable via WebSocket at {self.endpoint}: {e}" + ) + raise AgentConnectionError( + self.name, + f"Agent '{self.name}' is not available via WebSocket at " + f"{self.endpoint}. The agent may be down or unreachable.", + ) from e + except websockets.exceptions.WebSocketException as e: + logger.error(f"WebSocket error with agent '{self.name}': {e}") + raise AgentConnectionError( + self.name, + f"WebSocket communication failed with agent '{self.name}': {e}", + ) from e + except TimeoutError as e: + logger.error( + f"Agent '{self.name}' timed out via WebSocket at {self.endpoint}: {e}" + ) + raise AgentTimeoutError( + self.name, + f"Agent '{self.name}' timed out while processing the request " + "via WebSocket.", + ) from e # Origianl def invoke_api( @@ -124,12 +187,57 @@ def invoke_api( } inject(headers) logger.info("Beginning response processing") - response = requests.post(self.endpoint_api, data=input_message, headers=headers) + + try: + response = requests.post( + self.endpoint_api, data=input_message, headers=headers, timeout=120 + ) + except requests.exceptions.ConnectionError as e: + logger.error( + f"Agent '{self.name}' is unreachable at {self.endpoint_api}: {e}" + ) + raise AgentConnectionError( + self.name, + f"Agent '{self.name}' is not available at {self.endpoint_api}. " + "The agent may be down or unreachable.", + ) from e + except requests.exceptions.Timeout as e: + logger.error( + f"Agent '{self.name}' timed out at {self.endpoint_api}: {e}" + ) + raise AgentTimeoutError( + self.name, + f"Agent '{self.name}' timed out while processing the request.", + ) from e + except requests.exceptions.RequestException as e: + logger.error(f"Request to agent '{self.name}' failed: {e}") + raise AgentConnectionError( + self.name, + f"Failed to communicate with agent '{self.name}': {e}", + ) from e if response.status_code != 200: - raise Exception(f"Failed to invoke agent API: {response.status_code} - {response.text}") + detail = "" + try: + error_body = response.json() + detail = error_body.get("detail", response.text) + except Exception: + detail = response.text + logger.error(f"Agent '{self.name}' returned HTTP {response.status_code}: {detail}") + raise AgentResponseError(self.name, response.status_code, detail) + + try: + result = response.json() + except Exception as e: + logger.error(f"Agent '{self.name}' returned invalid JSON: {e}") + raise AgentInvalidResponseError( + self.name, + f"Agent '{self.name}' returned a response that could not " + "be parsed as JSON.", + ) from e + logger.info("Final response complete") - return response.json() + return result async def invoke_sse( self, @@ -148,10 +256,48 @@ async def invoke_sse( } inject(headers) logger.info("Beginning response processing") - response = requests.post(f"{self.endpoint_api}/sse", data=input_message, headers=headers) + + try: + response = requests.post( + f"{self.endpoint_api}/sse", data=input_message, headers=headers, timeout=120 + ) + except requests.exceptions.ConnectionError as e: + logger.error( + f"Agent '{self.name}' is unreachable at " + f"{self.endpoint_api}/sse: {e}" + ) + raise AgentConnectionError( + self.name, + f"Agent '{self.name}' is not available at " + f"{self.endpoint_api}/sse. " + "The agent may be down or unreachable.", + ) from e + except requests.exceptions.Timeout as e: + logger.error( + f"Agent '{self.name}' timed out at " + f"{self.endpoint_api}/sse: {e}" + ) + raise AgentTimeoutError( + self.name, + f"Agent '{self.name}' timed out while processing the request.", + ) from e + except requests.exceptions.RequestException as e: + logger.error(f"Request to agent '{self.name}' failed: {e}") + raise AgentConnectionError( + self.name, + f"Failed to communicate with agent '{self.name}': {e}", + ) from e if response.status_code != 200: - raise Exception(f"Failed to invoke agent API: {response.status_code} - {response.text}") + detail = "" + try: + error_body = response.json() + detail = error_body.get("detail", response.text) + except Exception: + detail = response.text + logger.error(f"Agent '{self.name}' returned HTTP {response.status_code}: {detail}") + raise AgentResponseError(self.name, response.status_code, detail) + logger.info("Final response complete") # Iterate over the response content line by line and yield each decoded line. for line in response.iter_lines(): diff --git a/src/orchestrators/assistant-orchestrator/orchestrator/recipient_chooser.py b/src/orchestrators/assistant-orchestrator/orchestrator/recipient_chooser.py index 3b3e970b8..83d7dc35b 100644 --- a/src/orchestrators/assistant-orchestrator/orchestrator/recipient_chooser.py +++ b/src/orchestrators/assistant-orchestrator/orchestrator/recipient_chooser.py @@ -1,12 +1,22 @@ import json +import logging import requests +import requests.exceptions from opentelemetry.propagate import inject from pydantic import BaseModel, ConfigDict -from agents import RecipientChooserAgent +from agents import ( + AgentConnectionError, + AgentInvalidResponseError, + AgentResponseError, + AgentTimeoutError, + RecipientChooserAgent, +) from model import Conversation +logger = logging.getLogger(__name__) + class ReqAgent(BaseModel): name: str @@ -76,15 +86,76 @@ async def choose_recipient( headers = {"taAgwKey": self.agent.api_key, "Authorization": authorization} inject(headers) - response = requests.post( - self.agent.endpoint, - headers=headers, - data=body_json, - ).json() + + try: + raw_response = requests.post( + self.agent.endpoint, + headers=headers, + data=body_json, + timeout=120, + ) + except requests.exceptions.ConnectionError as e: + logger.error( + f"Agent selector '{self.agent.name}' is unreachable " + f"at {self.agent.endpoint}: {e}" + ) + raise AgentConnectionError( + self.agent.name, + f"Agent selector '{self.agent.name}' is not available at " + f"{self.agent.endpoint}. " + "The service may be down or unreachable.", + ) from e + except requests.exceptions.Timeout as e: + logger.error(f"Agent selector '{self.agent.name}' timed out: {e}") + raise AgentTimeoutError( + self.agent.name, + f"Agent selector '{self.agent.name}' timed out while choosing a recipient.", + ) from e + except requests.exceptions.RequestException as e: + logger.error(f"Request to agent selector '{self.agent.name}' failed: {e}") + raise AgentConnectionError( + self.agent.name, + f"Failed to communicate with agent selector '{self.agent.name}': {e}", + ) from e + + if raw_response.status_code != 200: + detail = "" + try: + error_body = raw_response.json() + detail = error_body.get("detail", raw_response.text) + except Exception: + detail = raw_response.text + logger.error( + f"Agent selector '{self.agent.name}' returned " + f"HTTP {raw_response.status_code}: {detail}" + ) + raise AgentResponseError(self.agent.name, raw_response.status_code, detail) + + try: + response = raw_response.json() + except Exception as e: + logger.error(f"Agent selector '{self.agent.name}' returned invalid JSON: {e}") + raise AgentInvalidResponseError( + self.agent.name, + f"Agent selector '{self.agent.name}' returned a response that could not be parsed.", + ) from e + if response: - response_payload = ResponsePayload(**response) - clean_json = RecipientChooser._clean_output(response_payload.output_raw) - sel_agent: SelectedAgent = SelectedAgent(**json.loads(clean_json)) + try: + response_payload = ResponsePayload(**response) + clean_json = RecipientChooser._clean_output(response_payload.output_raw) + sel_agent: SelectedAgent = SelectedAgent(**json.loads(clean_json)) + except (json.JSONDecodeError, KeyError, Exception) as e: + logger.error( + f"Agent selector '{self.agent.name}' returned " + f"unparseable agent selection: {e}" + ) + raise AgentInvalidResponseError( + self.agent.name, + f"Agent selector '{self.agent.name}' returned a response " + "that could not be parsed into an agent " + f"selection: {e}", + ) from e return sel_agent else: raise Exception("Unable to determine recipient") diff --git a/src/orchestrators/assistant-orchestrator/orchestrator/routes/apis.py b/src/orchestrators/assistant-orchestrator/orchestrator/routes/apis.py index eec272074..aeb75f932 100644 --- a/src/orchestrators/assistant-orchestrator/orchestrator/routes/apis.py +++ b/src/orchestrators/assistant-orchestrator/orchestrator/routes/apis.py @@ -1,10 +1,17 @@ import logging from contextlib import nullcontext +import requests.exceptions from fastapi import APIRouter, Depends, HTTPException from fastapi.security import APIKeyHeader from ska_utils import get_telemetry +from agents import ( + AgentConnectionError, + AgentInvalidResponseError, + AgentResponseError, + AgentTimeoutError, +) from context_directive import parse_context_directives from jose_types import ExtraData from model.requests import ConversationMessageRequest @@ -92,7 +99,77 @@ async def add_conversation_message_by_id( selected_agent = await rec_chooser.choose_recipient( request.message, conv, authorization ) + except AgentConnectionError as e: + logger.error(f"Agent selector service is unreachable: {e}") + raise HTTPException( + status_code=502, + detail=( + f"Agent selector service '{e.agent_name}' is not available. " + "The service may be down or unreachable." + ), + ) from e + except AgentTimeoutError as e: + logger.error(f"Agent selector service timed out: {e}") + raise HTTPException( + status_code=504, + detail=( + f"Agent selector service '{e.agent_name}' " + "timed out while choosing a recipient." + ), + ) from e + except AgentResponseError as e: + logger.error(f"Agent selector service returned error: {e}") + if e.status_code == 401: + raise HTTPException( + status_code=401, + detail=( + f"Agent selector service '{e.agent_name}' " + f"authentication failed: {e.detail}" + ), + ) from e + elif e.status_code == 429: + raise HTTPException( + status_code=429, + detail=( + f"Agent selector service '{e.agent_name}' " + "is rate limited. Please try again later." + ), + ) from e + else: + raise HTTPException( + status_code=502, + detail=( + f"Agent selector service '{e.agent_name}' " + f"returned an error (HTTP {e.status_code}): " + f"{e.detail}" + ), + ) from e + except AgentInvalidResponseError as e: + logger.error(f"Agent selector returned invalid response: {e}") + raise HTTPException( + status_code=502, + detail=( + f"Agent selector service '{e.agent_name}' returned " + "an invalid response that could not be processed." + ), + ) from e + except requests.exceptions.ConnectionError as e: + logger.error(f"Agent selector service is unreachable (connection error): {e}") + raise HTTPException( + status_code=502, + detail=( + "Agent selector service is not available. " + "The service may be down or unreachable." + ), + ) from e + except requests.exceptions.Timeout as e: + logger.error(f"Agent selector service timed out (timeout): {e}") + raise HTTPException( + status_code=504, + detail="Agent selector service timed out while choosing a recipient.", + ) from e except Exception as e: + logger.error(f"Error choosing recipient: {e}") raise HTTPException( status_code=500, detail=f"Error retrieving agent to handle conversation message --- {e}", @@ -124,7 +201,63 @@ async def add_conversation_message_by_id( else nullcontext() ): logger.info("Begin processing invoke_api") - response = agent.invoke_api(conv, authorization, request.image_data) + + try: + response = agent.invoke_api(conv, authorization, request.image_data) + except AgentConnectionError as e: + logger.error(f"Agent unavailable: {e}") + raise HTTPException( + status_code=502, + detail=( + f"Agent '{sel_agent_name}' is not available. " + "The agent may be down or unreachable. " + "Please try again later." + ), + ) from e + except AgentTimeoutError as e: + logger.error(f"Agent timed out: {e}") + raise HTTPException( + status_code=504, + detail=( + f"Agent '{sel_agent_name}' timed out while " + "processing the request. Please try again later." + ), + ) from e + except AgentResponseError as e: + logger.error(f"Agent returned error: {e}") + if e.status_code == 401: + raise HTTPException( + status_code=401, + detail=( + f"Agent '{sel_agent_name}' " + f"authentication failed: {e.detail}" + ), + ) from e + elif e.status_code == 429: + raise HTTPException( + status_code=429, + detail=( + f"Agent '{sel_agent_name}' is rate limited. " + "Please try again later." + ), + ) from e + else: + raise HTTPException( + status_code=502, + detail=( + f"Agent '{sel_agent_name}' returned an error " + f"(HTTP {e.status_code}): {e.detail}" + ), + ) from e + except AgentInvalidResponseError as e: + logger.error(f"Agent returned invalid response: {e}") + raise HTTPException( + status_code=502, + detail=( + f"Agent '{sel_agent_name}' returned an invalid " + "response that could not be processed." + ), + ) from e try: # Set the agent response from raw output diff --git a/src/orchestrators/assistant-orchestrator/orchestrator/routes/sse.py b/src/orchestrators/assistant-orchestrator/orchestrator/routes/sse.py index 0ef014f46..41ac72a86 100644 --- a/src/orchestrators/assistant-orchestrator/orchestrator/routes/sse.py +++ b/src/orchestrators/assistant-orchestrator/orchestrator/routes/sse.py @@ -4,12 +4,18 @@ from contextlib import nullcontext from typing import Any +import requests.exceptions from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from fastapi.security import APIKeyHeader from ska_utils import get_telemetry -from agents import Conversation +from agents import ( + AgentConnectionError, + AgentResponseError, + AgentTimeoutError, + Conversation, +) from context_directive import parse_context_directives from jose_types import ExtraData from model.conversation import SseError, SseEventType, SseFinalMessage, SseMessage @@ -78,12 +84,70 @@ async def sse_event_response( selected_agent = await rec_chooser.choose_recipient( request.message, conv, authorization ) + except AgentConnectionError as e: + logger.error(f"Agent selector service is unreachable: {e}") + sse_error = SseError( + error=( + f"Agent selector service '{e.agent_name}' is not " + "available. The service may be down or unreachable." + ), + ) + yield format_sse_message(sse_error.model_dump(), SseEventType.UNKNOWN) + return + except AgentTimeoutError as e: + logger.error(f"Agent selector service timed out: {e}") + sse_error = SseError( + error=( + f"Agent selector service '{e.agent_name}' " + "timed out while choosing a recipient." + ), + ) + yield format_sse_message(sse_error.model_dump(), SseEventType.UNKNOWN) + return + except AgentResponseError as e: + logger.error(f"Agent selector service returned error: {e}") + if e.status_code == 401: + error_msg = ( + f"Agent selector service '{e.agent_name}' " + f"authentication failed: {e.detail}" + ) + elif e.status_code == 429: + error_msg = ( + f"Agent selector service '{e.agent_name}' " + "is rate limited. Please try again later." + ) + else: + error_msg = ( + f"Agent selector service '{e.agent_name}' " + f"returned an error " + f"(HTTP {e.status_code}): {e.detail}" + ) + sse_error = SseError(error=error_msg) + yield format_sse_message(sse_error.model_dump(), SseEventType.UNKNOWN) + return + except requests.exceptions.ConnectionError as e: + logger.error(f"Agent selector service is unreachable (connection error): {e}") + sse_error = SseError( + error=( + "Agent selector service is not available. " + "The service may be down or unreachable." + ), + ) + yield format_sse_message(sse_error.model_dump(), SseEventType.UNKNOWN) + return + except requests.exceptions.Timeout as e: + logger.error(f"Agent selector service timed out (timeout): {e}") + sse_error = SseError( + error="Agent selector service timed out while choosing a recipient.", + ) + yield format_sse_message(sse_error.model_dump(), SseEventType.UNKNOWN) + return except Exception as e: + logger.error(f"Error choosing recipient: {e}") sse_error = SseError( error=f"Error retrieving agent: {e}", ) yield format_sse_message(sse_error.model_dump(), SseEventType.UNKNOWN) - raise e return # Determine the selected agent @@ -150,7 +214,48 @@ async def sse_event_response( except json.JSONDecodeError: print(f"Error decoding JSON: {json_data_str}") await asyncio.sleep(0.001) + except AgentConnectionError as e: + logger.error(f"Agent unavailable during SSE streaming: {e}") + sse_error = SseError( + error=( + f"Agent '{sel_agent_name}' is not available. " + "The agent may be down or unreachable." + ), + ) + yield format_sse_message(sse_error.model_dump(), SseEventType.UNKNOWN) + return + except AgentTimeoutError as e: + logger.error(f"Agent timed out during SSE streaming: {e}") + sse_error = SseError( + error=( + f"Agent '{sel_agent_name}' timed out while " + "processing the request." + ), + ) + yield format_sse_message(sse_error.model_dump(), SseEventType.UNKNOWN) + return + except AgentResponseError as e: + logger.error(f"Agent returned error during SSE streaming: {e}") + if e.status_code == 401: + error_msg = ( + f"Agent '{sel_agent_name}' " + f"authentication failed: {e.detail}" + ) + elif e.status_code == 429: + error_msg = ( + f"Agent '{sel_agent_name}' is rate limited. " + "Please try again later." + ) + else: + error_msg = ( + f"Agent '{sel_agent_name}' returned an error " + f"(HTTP {e.status_code}): {e.detail}" + ) + sse_error = SseError(error=error_msg) + yield format_sse_message(sse_error.model_dump(), SseEventType.UNKNOWN) + return except Exception as e: + logger.error(f"Unexpected error during agent streaming: {e}") sse_error = SseError( error=f"Error during agent streaming: {e}", ) diff --git a/src/orchestrators/assistant-orchestrator/orchestrator/routes/websockets.py b/src/orchestrators/assistant-orchestrator/orchestrator/routes/websockets.py index b5ed4374a..150982e03 100644 --- a/src/orchestrators/assistant-orchestrator/orchestrator/routes/websockets.py +++ b/src/orchestrators/assistant-orchestrator/orchestrator/routes/websockets.py @@ -1,3 +1,4 @@ +import logging from contextlib import nullcontext from fastapi import ( @@ -7,6 +8,11 @@ ) from ska_utils import get_telemetry +from agents import ( + AgentConnectionError, + AgentResponseError, + AgentTimeoutError, +) from context_directive import parse_context_directives from jose_types import ExtraData @@ -19,6 +25,8 @@ get_rec_chooser, ) +logger = logging.getLogger(__name__) + conv_manager = get_conv_manager() conn_manager = get_conn_manager() rec_chooser = get_rec_chooser() @@ -61,9 +69,65 @@ async def invoke_stream( if jt.telemetry_enabled() else nullcontext() ): - selected_agent = await rec_chooser.choose_recipient( - message, conv, authorization - ) + try: + selected_agent = await rec_chooser.choose_recipient( + message, conv, authorization + ) + except AgentConnectionError as e: + logger.error(f"Agent selector service is unreachable: {e}") + await websocket.send_json({ + "error": True, + "error_type": "agent_selector_unavailable", + "message": ( + f"Agent selector service '{e.agent_name}' " + "is not available. The service may be " + "down or unreachable." + ), + }) + continue + except AgentTimeoutError as e: + logger.error(f"Agent selector service timed out: {e}") + await websocket.send_json({ + "error": True, + "error_type": "agent_selector_timeout", + "message": ( + f"Agent selector service '{e.agent_name}' " + "timed out while choosing a recipient." + ), + }) + continue + except AgentResponseError as e: + logger.error(f"Agent selector service returned error: {e}") + if e.status_code == 401: + error_msg = ( + f"Agent selector service '{e.agent_name}' " + f"authentication failed: {e.detail}" + ) + elif e.status_code == 429: + error_msg = ( + f"Agent selector service '{e.agent_name}' " + "is rate limited. Please try again later." + ) + else: + error_msg = ( + f"Agent selector service '{e.agent_name}' " + "returned an error " + f"(HTTP {e.status_code}): {e.detail}" + ) + await websocket.send_json({ + "error": True, + "error_type": "agent_selector_error", + "message": error_msg, + }) + continue + except Exception as e: + logger.error(f"Error choosing recipient: {e}") + await websocket.send_json({ + "error": True, + "error_type": "agent_selector_unavailable", + "message": f"Agent selector service encountered an error: {e}", + }) + continue if selected_agent.agent_name not in agent_catalog.agents: agent = fallback_agent sel_agent_name = fallback_agent.name @@ -96,14 +160,75 @@ async def invoke_stream( ): # Stream agent response to client response = "" - async for content in agent.invoke_stream(conv, authorization=authorization): - try: - extra_data: ExtraData = ExtraData.new_from_json(content) - context_directives = parse_context_directives(extra_data) - await conv_manager.process_context_directives(conv, context_directives) - except Exception: - response = f"{response}{content}" - await websocket.send_text(content) + try: + async for content in agent.invoke_stream(conv, authorization=authorization): + try: + extra_data: ExtraData = ExtraData.new_from_json(content) + context_directives = parse_context_directives(extra_data) + await conv_manager.process_context_directives( + conv, context_directives + ) + except Exception: + response = f"{response}{content}" + await websocket.send_text(content) + except AgentConnectionError as e: + logger.error(f"Agent unavailable during WebSocket streaming: {e}") + await websocket.send_json({ + "error": True, + "error_type": "agent_unavailable", + "message": ( + f"Agent '{sel_agent_name}' is not available. " + "The agent may be down or unreachable." + ), + }) + continue + except AgentTimeoutError as e: + logger.error(f"Agent timed out during WebSocket streaming: {e}") + await websocket.send_json({ + "error": True, + "error_type": "agent_timeout", + "message": ( + f"Agent '{sel_agent_name}' timed out " + "while processing the request." + ), + }) + continue + except AgentResponseError as e: + logger.error(f"Agent returned error during WebSocket streaming: {e}") + if e.status_code == 401: + error_msg = ( + f"Agent '{sel_agent_name}' " + f"authentication failed: {e.detail}" + ) + elif e.status_code == 429: + error_msg = ( + f"Agent '{sel_agent_name}' is rate " + "limited. Please try again later." + ) + else: + error_msg = ( + f"Agent '{sel_agent_name}' returned " + "an error " + f"(HTTP {e.status_code}): {e.detail}" + ) + await websocket.send_json({ + "error": True, + "error_type": "agent_error", + "message": error_msg, + }) + continue + except Exception as e: + logger.error(f"Unexpected error during agent streaming: {e}") + await websocket.send_json({ + "error": True, + "error_type": "unknown_error", + "message": ( + "An unexpected error occurred while " + "communicating with agent " + f"'{sel_agent_name}': {e}" + ), + }) + continue with ( jt.tracer.start_as_current_span("update-history-assistant") diff --git a/src/sk-agents/README.md b/src/sk-agents/README.md index ba693b835..3fe6c6259 100644 --- a/src/sk-agents/README.md +++ b/src/sk-agents/README.md @@ -55,7 +55,7 @@ $ fastapi run src/sk_agents/app.py You can test the agent by visiting http://localhost:8000/docs -![Agent Swagger UI](doc/assets/demo-1.png) +![Agent Swagger UI](docs/assets/demo-1.png) ### Additional Documentation - [Configuring an Agent](/src/sk-agents/demos/01_getting_started/README.md) diff --git a/src/sk-agents/src/sk_agents/exceptions.py b/src/sk-agents/src/sk_agents/exceptions.py index 8cdf5a2b7..08d7130e3 100644 --- a/src/sk-agents/src/sk_agents/exceptions.py +++ b/src/sk-agents/src/sk_agents/exceptions.py @@ -29,6 +29,53 @@ def __init__(self, message: str): self.message = message +class AgentUnavailableException(AgentsException): + """Exception raised when a target agent is unreachable. + + Covers connection refused, DNS failure, and timeout scenarios. + """ + + agent_name: str + message: str + + def __init__(self, agent_name: str, message: str): + self.agent_name = agent_name + self.message = message + super().__init__(f"Agent '{agent_name}' is unavailable: {message}") + + +class LLMAuthenticationException(AgentsException): + """Exception raised when the LLM provider rejects authentication. + + Covers invalid API key, expired token, etc. + """ + + status_code: int + message: str + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + super().__init__(f"LLM authentication failed (HTTP {status_code}): {message}") + + +class LLMServiceException(AgentsException): + """Exception raised for LLM service-level errors. + + Covers rate limits, server errors, model not found, etc. + """ + + error_type: str + message: str + status_code: int | None + + def __init__(self, error_type: str, message: str, status_code: int | None = None): + self.error_type = error_type + self.message = message + self.status_code = status_code + super().__init__(f"LLM service error ({error_type}): {message}") + + class PersistenceCreateError(AgentsException): """Exception raised for errors during task creation.""" diff --git a/src/sk-agents/src/sk_agents/routes.py b/src/sk-agents/src/sk_agents/routes.py index 117e52e4e..424d54258 100644 --- a/src/sk-agents/src/sk_agents/routes.py +++ b/src/sk-agents/src/sk_agents/routes.py @@ -27,6 +27,12 @@ TA_PROVIDER_ORG, TA_PROVIDER_URL, ) +from sk_agents.exceptions import ( + AgentInvokeException, + AgentUnavailableException, + LLMAuthenticationException, + LLMServiceException, +) from sk_agents.persistence.task_persistence_manager import TaskPersistenceManager from sk_agents.ska_types import ( BaseConfig, @@ -230,15 +236,53 @@ async def invoke(inputs: input_class, request: Request) -> InvokeResponse[output if st.telemetry_enabled() else nullcontext() ): - match root_handler_name: - case "skagents": - handler: BaseHandler = skagents_handle(config, app_config, authorization) - case _: - raise ValueError(f"Unknown apiVersion: {config.apiVersion}") - inv_inputs = inputs.__dict__ - output = await handler.invoke(inputs=inv_inputs) - return output + try: + match root_handler_name: + case "skagents": + handler: BaseHandler = skagents_handle( + config, app_config, authorization + ) + case _: + raise ValueError(f"Unknown apiVersion: {config.apiVersion}") + + output = await handler.invoke(inputs=inv_inputs) + return output + except AgentUnavailableException as e: + logger.exception(f"Agent unavailable: {e}") + raise HTTPException( + status_code=502, detail=f"Agent unavailable: {e.message}" + ) from e + except LLMAuthenticationException as e: + logger.exception(f"LLM authentication failed: {e}") + raise HTTPException( + status_code=401, + detail=f"LLM authentication failed: {e.message}", + ) from e + except LLMServiceException as e: + sc = { + "rate_limit": 429, "model_not_found": 404, + "server_error": 502, "content_filter": 400, + }.get(e.error_type, 502) + logger.exception(f"LLM service error: {e}") + raise HTTPException( + status_code=sc, + detail=f"LLM service error ({e.error_type}): {e.message}", + ) from e + except AgentInvokeException as e: + logger.exception(f"Agent invocation failed: {e}") + raise HTTPException( + status_code=500, detail=f"Agent invocation failed: {e.message}" + ) from e + except HTTPException: + raise + except ValueError: + raise + except Exception as e: + logger.exception(f"Unexpected error: {e}") + raise HTTPException( + status_code=500, detail=f"Internal Server Error: {str(e)}" + ) from e @router.post("/sse") @docstring_parameter(description) @@ -267,8 +311,43 @@ async def event_generator(): config, app_config, authorization ) # noinspection PyTypeChecker - async for content in handler.invoke_stream(inputs=inv_inputs): - yield get_sse_event_for_response(content) + try: + async for content in handler.invoke_stream(inputs=inv_inputs): + yield get_sse_event_for_response(content) + except AgentUnavailableException as e: + logger.exception(f"Agent unavailable in SSE: {e}") + yield get_sse_event_for_response({ + "error": f"Agent unavailable: {e.message}", + "status_code": 502, + }) + except LLMAuthenticationException as e: + logger.exception(f"LLM auth failed in SSE: {e}") + yield get_sse_event_for_response({ + "error": f"LLM authentication failed: {e.message}", + "status_code": 401, + }) + except LLMServiceException as e: + logger.exception(f"LLM service error in SSE: {e}") + sc = getattr(e, "status_code", 502) or 502 + yield get_sse_event_for_response({ + "error": ( + f"LLM service error ({e.error_type}): " + f"{e.message}" + ), + "status_code": sc, + }) + except AgentInvokeException as e: + logger.exception(f"Agent invocation failed in SSE: {e}") + yield get_sse_event_for_response({ + "error": f"Agent invocation failed: {e.message}", + "status_code": 500, + }) + except Exception as e: + logger.exception(f"Unexpected error in SSE: {e}") + yield get_sse_event_for_response({ + "error": f"Internal Server Error: {str(e)}", + "status_code": 500, + }) case _: logger.exception( "Unknown apiVersion: %s", config.apiVersion, exc_info=True @@ -367,7 +446,39 @@ async def chat(message: input_class, user_id: str = Depends(get_user_id)) -> Sta teal_handler = Routes.get_task_handler( config, app_config, user_id, state_manager, mcp_discovery_manager ) - response_content = await teal_handler.invoke(user_id, message) + try: + response_content = await teal_handler.invoke(user_id, message) + except AgentUnavailableException as e: + logger.exception(f"Agent unavailable: {e}") + raise HTTPException( + status_code=502, detail=f"Agent unavailable: {e.message}" + ) from e + except LLMAuthenticationException as e: + logger.exception(f"LLM authentication failed: {e}") + raise HTTPException( + status_code=401, + detail=f"LLM authentication failed: {e.message}", + ) from e + except LLMServiceException as e: + sc = { + "rate_limit": 429, "model_not_found": 404, + "server_error": 502, "content_filter": 400, + }.get(e.error_type, 502) + logger.exception(f"LLM service error: {e}") + raise HTTPException( + status_code=sc, + detail=f"LLM service error ({e.error_type}): {e.message}", + ) from e + except AgentInvokeException as e: + logger.exception(f"Agent invocation failed: {e}") + raise HTTPException( + status_code=500, detail=f"Agent invocation failed: {e.message}" + ) from e + except Exception as e: + logger.exception(f"Unexpected error: {e}") + raise HTTPException( + status_code=500, detail=f"Internal Server Error: {str(e)}" + ) from e # Return response with state identifiers status = TaskStatus.COMPLETED.value if type(response_content) is HitlResponse: @@ -399,9 +510,41 @@ async def resume(request_id: str, request: Request, body: ResumeRequest): ) try: return await teal_handler.resume_task(authorization, request_id, body, stream=False) + except AgentUnavailableException as e: + logger.exception(f"Agent unavailable in resume: {e}") + raise HTTPException( + status_code=502, detail=f"Agent unavailable: {e.message}" + ) from e + except LLMAuthenticationException as e: + logger.exception(f"LLM auth failed in resume: {e}") + raise HTTPException( + status_code=401, detail=f"LLM authentication failed: {e.message}" + ) from e + except LLMServiceException as e: + sc = { + "rate_limit": 429, "model_not_found": 404, + "server_error": 502, "content_filter": 400, + }.get(e.error_type, 502) + logger.exception(f"LLM service error in resume: {e}") + raise HTTPException( + status_code=sc, + detail=( + f"LLM service error ({e.error_type}): " + f"{e.message}" + ), + ) from e + except AgentInvokeException as e: + logger.exception(f"Agent invocation failed in resume: {e}") + raise HTTPException( + status_code=500, + detail=f"Agent invocation failed: {e.message}", + ) from e except Exception as e: logger.exception(f"Error in resume: {e}") - raise HTTPException(status_code=500, detail="Internal Server Error") from e + raise HTTPException( + status_code=500, + detail="Internal Server Error", + ) from e @router.post("/tealagents/v1alpha1/resume/{request_id}/sse") async def resume_sse(request_id: str, request: Request, body: ResumeRequest): @@ -416,9 +559,37 @@ async def event_generator(): authorization, request_id, body, stream=True ): yield get_sse_event_for_response(content) + except AgentUnavailableException as e: + logger.exception(f"Agent unavailable in resume_sse: {e}") + yield get_sse_event_for_response({ + "error": f"Agent unavailable: {e.message}", + "status_code": 502, + }) + except LLMAuthenticationException as e: + logger.exception(f"LLM auth failed in resume_sse: {e}") + yield get_sse_event_for_response({ + "error": f"LLM authentication failed: {e.message}", + "status_code": 401, + }) + except LLMServiceException as e: + logger.exception(f"LLM service error in resume_sse: {e}") + sc = getattr(e, 'status_code', 502) or 502 + yield get_sse_event_for_response({ + "error": ( + f"LLM service error ({e.error_type}): " + f"{e.message}" + ), + "status_code": sc, + }) + except AgentInvokeException as e: + logger.exception(f"Agent invocation failed in resume_sse: {e}") + yield get_sse_event_for_response({ + "error": f"Agent invocation failed: {e.message}", + "status_code": 500, + }) except Exception as e: logger.exception(f"Error in resume_sse: {e}") - raise HTTPException(status_code=500, detail="Internal Server Error") from e + raise return StreamingResponse(event_generator(), media_type="text/event-stream") diff --git a/src/sk-agents/src/sk_agents/skagents/v1/chat/chat_agents.py b/src/sk-agents/src/sk_agents/skagents/v1/chat/chat_agents.py index 78e4f1a7b..6083f4ee9 100644 --- a/src/sk-agents/src/sk_agents/skagents/v1/chat/chat_agents.py +++ b/src/sk-agents/src/sk_agents/skagents/v1/chat/chat_agents.py @@ -2,14 +2,21 @@ import time import uuid from collections.abc import AsyncIterable -from contextlib import nullcontext from typing import Any +import httpx +import openai from semantic_kernel.contents import ChatMessageContent, TextContent from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.utils.author_role import AuthorRole -from ska_utils import get_telemetry +from semantic_kernel.exceptions import ServiceResponseException +from ska_utils import AgentTelemetryLogger, get_telemetry +from sk_agents.exceptions import ( + AgentUnavailableException, + LLMAuthenticationException, + LLMServiceException, +) from sk_agents.extra_data_collector import ExtraDataCollector, ExtraDataPartial from sk_agents.ska_types import ( BaseConfig, @@ -50,6 +57,76 @@ def __init__(self, config: BaseConfig, agent_builder: AgentBuilder, is_v2: bool self.agent_builder = agent_builder + @staticmethod + def _extract_agent_name_from_error(error: Exception) -> str: + """Extract agent name from error message if possible.""" + error_str = str(error) + for prefix in ["Error invoking ", "Error calling "]: + if prefix in error_str: + rest = error_str.split(prefix, 1)[1] + if ":" in rest: + return rest.split(":")[0] + return "unknown" + + @staticmethod + def _classify_llm_error(error: Exception) -> None: + """ + This unwraps ServiceResponseException from Semantic Kernel + and raises the appropriate specific exception. + If no specific classification is found, returns without raising. + """ + if isinstance(error, ServiceResponseException): + inner = error.__cause__ or error + else: + inner = error + + # Check for openai-specific exceptions + if isinstance(inner, openai.AuthenticationError): + raise LLMAuthenticationException( + status_code=401, + message=f"LLM authentication failed: {str(inner)}" + ) from error + if isinstance(inner, openai.PermissionDeniedError): + raise LLMAuthenticationException( + status_code=403, + message=f"LLM permission denied: {str(inner)}" + ) from error + if isinstance(inner, openai.RateLimitError): + raise LLMServiceException( + error_type="rate_limit", + message=f"LLM rate limit exceeded: {str(inner)}", + status_code=429, + ) from error + if isinstance(inner, openai.NotFoundError): + raise LLMServiceException( + error_type="model_not_found", + message=f"LLM model or resource not found: {str(inner)}", + status_code=404, + ) from error + if isinstance(inner, openai.APIStatusError): + raise LLMServiceException( + error_type="service_error", + message=f"LLM API error: {str(inner)}", + status_code=getattr(inner, "status_code", None), + ) from error + if isinstance(inner, openai.APIConnectionError): + raise LLMServiceException( + error_type="connection_error", + message=f"Cannot connect to LLM service: {str(inner)}", + ) from error + if isinstance(inner, openai.APITimeoutError): + raise LLMServiceException( + error_type="timeout", + message=f"LLM request timed out: {str(inner)}", + ) from error + + # If it was a ServiceResponseException but not openai, still raise as LLM error + if isinstance(error, ServiceResponseException): + raise LLMServiceException( + error_type="service_error", + message=f"LLM service error: {str(error)}", + ) from error + @staticmethod def _augment_with_user_context( inputs: dict[str, Any] | None, chat_history: ChatHistory @@ -62,6 +139,14 @@ def _augment_with_user_context( ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text=content)]) ) + @staticmethod + def _extract_user_isid(inputs: dict[str, Any] | None) -> str | None: + """Extract user ISID from inputs user_context if available.""" + if inputs and "user_context" in inputs and inputs["user_context"]: + user_context = inputs["user_context"] + return user_context.get("user.isid") or user_context.get("isid") + return None + async def invoke_stream( self, inputs: dict[str, Any] | None = None ) -> AsyncIterable[PartialResponse | InvokeResponse]: @@ -69,6 +154,16 @@ async def invoke_stream( extra_data_collector = ExtraDataCollector() agent = self.agent_builder.build_agent(self.config.get_agent(), extra_data_collector) + # Initialize agent telemetry logger with rich metadata + agent_config = self.config.get_agent() + user_isid = ChatAgents._extract_user_isid(inputs) + agent_telemetry = AgentTelemetryLogger( + agent_name=agent_config.name, + model_name=agent_config.model, + user_isid=user_isid, + telemetry=jt, + ) + # Initialize tasks count and token metrics completion_tokens: int = 0 prompt_tokens: int = 0 @@ -87,50 +182,97 @@ async def invoke_stream( request_id = str(uuid.uuid4().hex) # Process the final task with streaming - with ( - jt.tracer.start_as_current_span("handler-stream") - if jt.telemetry_enabled() - else nullcontext() + with agent_telemetry.trace_agent_invocation( + "handler-stream", session_id=session_id, request_id=request_id ) as stream_span: first_token_received = False start_time = time.time() + titme_to_first_token_ms = 0.0 logger.info("Beginning processing invoke stream") - async for chunk in agent.invoke_stream(chat_history): - if not first_token_received: - first_token_time = time.time() - titme_to_first_token_ms = (first_token_time - start_time) * 1000 - first_token_received = True - # Initialize content as the partial message in chunk - content = chunk.content - # Calculate usage metrics - call_usage = get_token_usage_for_response(agent.get_model_type(), chunk) - completion_tokens += call_usage.completion_tokens - prompt_tokens += call_usage.prompt_tokens - total_tokens += call_usage.total_tokens - try: - # Attempt to parse as ExtraDataPartial - extra_data_partial: ExtraDataPartial = ExtraDataPartial.new_from_json(content) - extra_data_collector.add_extra_data_items(extra_data_partial.extra_data) - except Exception: - if len(content) > 0: - # Handle and return partial response - final_response.append(content) - yield PartialResponse( - session_id=session_id, - source=f"{self.name}:{self.version}", - request_id=request_id, - output_partial=content, + try: + async for chunk in agent.invoke_stream(chat_history): + if not first_token_received: + first_token_time = time.time() + titme_to_first_token_ms = (first_token_time - start_time) * 1000 + first_token_received = True + # Initialize content as the partial message in chunk + content = chunk.content + # Calculate usage metrics + call_usage = get_token_usage_for_response(agent.get_model_type(), chunk) + completion_tokens += call_usage.completion_tokens + prompt_tokens += call_usage.prompt_tokens + total_tokens += call_usage.total_tokens + try: + # Attempt to parse as ExtraDataPartial + extra_data_partial: ExtraDataPartial = ( + ExtraDataPartial.new_from_json(content) + ) + extra_data_collector.add_extra_data_items( + extra_data_partial.extra_data ) + except Exception: + if len(content) > 0: + # Handle and return partial response + final_response.append(content) + yield PartialResponse( + session_id=session_id, + source=f"{self.name}:{self.version}", + request_id=request_id, + output_partial=content, + ) + except httpx.ConnectError as e: + agent_name = self._extract_agent_name_from_error(e) + raise AgentUnavailableException( + agent_name=agent_name, + message=f"Connection refused: {str(e)}", + ) from e + except httpx.HTTPStatusError as e: + agent_name = self._extract_agent_name_from_error(e) + if e.response.status_code in (502, 503, 504): + raise AgentUnavailableException( + agent_name=agent_name, + message=f"Agent returned HTTP {e.response.status_code}: {str(e)}", + ) from e + self._classify_llm_error(e) + except (ServiceResponseException, openai.OpenAIError) as e: + self._classify_llm_error(e) + raise # Build the final response with InvokeResponse logger.info("Building the final response with InvokeRespons") - if stream_span: - stream_span.set_attribute("completion_tokens", completion_tokens) - stream_span.set_attribute("prompt_tokens", prompt_tokens) - stream_span.set_attribute("total_tokens", total_tokens) - stream_span.add_event( - "agent_time_to_first_token", - attributes={"first_token_time_ms": titme_to_first_token_ms}, - ) + + # Record tool calls made by the agent during streaming + tool_calls = getattr(agent, "last_tool_calls", None) + if isinstance(tool_calls, list) and tool_calls: + agent_telemetry.record_tool_calls(tool_calls) + + # Record reasoning tokens from the agent (always report) + reasoning_tokens = getattr(agent, "last_reasoning_tokens", 0) + if not isinstance(reasoning_tokens, int): + reasoning_tokens = 0 + agent_telemetry.record_reasoning( + f"reasoning_tokens={reasoning_tokens}" + ) + + # Enrich span with agent metadata + agent_telemetry.enrich_span( + span=stream_span, + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + time_to_first_token_ms=titme_to_first_token_ms, + ) + + # Emit standardized structured log + agent_telemetry.emit_log( + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + ) + final_response = "".join(final_response) response = InvokeResponse( session_id=session_id, @@ -161,41 +303,96 @@ async def invoke( prompt_tokens: int = 0 total_tokens: int = 0 jt = get_telemetry() + + # Initialize agent telemetry logger with rich metadata + agent_config = self.config.get_agent() + user_isid = ChatAgents._extract_user_isid(inputs) + agent_telemetry = AgentTelemetryLogger( + agent_name=agent_config.name, + model_name=agent_config.model, + user_isid=user_isid, + telemetry=jt, + ) + session_id: str if "session_id" in inputs and inputs["session_id"]: session_id = inputs["session_id"] else: session_id = str(uuid.uuid4().hex) request_id = str(uuid.uuid4().hex) - with ( - jt.tracer.start_as_current_span("handler-invoke") - if jt.telemetry_enabled() - else nullcontext() + + with agent_telemetry.trace_agent_invocation( + "handler-invoke", session_id=session_id, request_id=request_id ) as invoke_span: first_token_received = False start_time = time.time() + titme_to_first_token_ms = 0.0 logger.info("Beginning processing invoke") - - async for content in agent.invoke(chat_history): - if not first_token_received: - first_token_time = time.time() - titme_to_first_token_ms = (first_token_time - start_time) * 1000 - first_token_received = True - response_content.append(content) - call_usage = get_token_usage_for_response(agent.get_model_type(), content) - completion_tokens += call_usage.completion_tokens - prompt_tokens += call_usage.prompt_tokens - total_tokens += call_usage.total_tokens + try: + async for content in agent.invoke(chat_history): + if not first_token_received: + first_token_time = time.time() + titme_to_first_token_ms = (first_token_time - start_time) * 1000 + first_token_received = True + response_content.append(content) + call_usage = get_token_usage_for_response(agent.get_model_type(), content) + completion_tokens += call_usage.completion_tokens + prompt_tokens += call_usage.prompt_tokens + total_tokens += call_usage.total_tokens + except httpx.ConnectError as e: + agent_name = self._extract_agent_name_from_error(e) + raise AgentUnavailableException( + agent_name=agent_name, + message=f"Connection refused: {str(e)}", + ) from e + except httpx.HTTPStatusError as e: + agent_name = self._extract_agent_name_from_error(e) + if e.response.status_code in (502, 503, 504): + raise AgentUnavailableException( + agent_name=agent_name, + message=f"Agent returned HTTP {e.response.status_code}: {str(e)}", + ) from e + self._classify_llm_error(e) + raise + except (ServiceResponseException, openai.OpenAIError) as e: + self._classify_llm_error(e) + raise logger.info("Building the final response with InvokeRespons") - if invoke_span: - invoke_span.set_attribute("completion_tokens", completion_tokens) - invoke_span.set_attribute("prompt_tokens", prompt_tokens) - invoke_span.set_attribute("total_tokens", total_tokens) - invoke_span.add_event( - "agent_response_time_ms", - attributes={"response_time_ms": titme_to_first_token_ms}, - ) + + # Record tool calls made by the agent during invocation + tool_calls = getattr(agent, "last_tool_calls", None) + if isinstance(tool_calls, list) and tool_calls: + agent_telemetry.record_tool_calls(tool_calls) + + # Record reasoning tokens from the agent (always report) + reasoning_tokens = getattr(agent, "last_reasoning_tokens", 0) + if not isinstance(reasoning_tokens, int): + reasoning_tokens = 0 + agent_telemetry.record_reasoning( + f"reasoning_tokens={reasoning_tokens}" + ) + + # Enrich span with agent metadata + agent_telemetry.enrich_span( + span=invoke_span, + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + time_to_first_token_ms=titme_to_first_token_ms, + ) + + # Emit standardized structured log + agent_telemetry.emit_log( + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + ) + response = InvokeResponse( session_id=session_id, source=f"{self.name}:{self.version}", diff --git a/src/sk-agents/src/sk_agents/skagents/v1/sequential/sequential_skagents.py b/src/sk-agents/src/sk_agents/skagents/v1/sequential/sequential_skagents.py index ff34f20f4..e7678c602 100644 --- a/src/sk-agents/src/sk_agents/skagents/v1/sequential/sequential_skagents.py +++ b/src/sk-agents/src/sk_agents/skagents/v1/sequential/sequential_skagents.py @@ -3,14 +3,21 @@ import time import uuid from collections.abc import AsyncIterable -from contextlib import nullcontext from copy import deepcopy from typing import Any +import openai from semantic_kernel.contents.chat_history import ChatHistory -from ska_utils import get_telemetry +from semantic_kernel.exceptions import ServiceResponseException +from ska_utils import AgentTelemetryLogger, get_telemetry -from sk_agents.exceptions import AgentInvokeException, InvalidConfigException +from sk_agents.exceptions import ( + AgentInvokeException, + AgentUnavailableException, + InvalidConfigException, + LLMAuthenticationException, + LLMServiceException, +) from sk_agents.extra_data_collector import ExtraDataCollector, ExtraDataPartial from sk_agents.ska_types import ( BaseConfig, @@ -114,6 +121,14 @@ def _parse_task_inputs( task_inputs = None return task_inputs + @staticmethod + def _extract_user_isid(inputs: dict[str, Any] | None) -> str | None: + """Extract user ISID from inputs user_context if available.""" + if inputs and "user_context" in inputs and inputs["user_context"]: + user_context = inputs["user_context"] + return user_context.get("user.isid") or user_context.get("isid") + return None + async def invoke_stream( self, inputs: dict[str, Any] | None = None ) -> AsyncIterable[PartialResponse | IntermediateTaskResponse | InvokeResponse]: @@ -136,11 +151,24 @@ async def invoke_stream( else: session_id = str(uuid.uuid4().hex) request_id = str(uuid.uuid4().hex) + + # Determine agent info from the last task's agent for telemetry + last_agent_config = None + agents = self.config.get_agents() + if agents: + last_agent_config = agents[-1] + + user_isid = SequentialSkagents._extract_user_isid(inputs) + agent_telemetry = AgentTelemetryLogger( + agent_name=last_agent_config.name if last_agent_config else self.name, + model_name=last_agent_config.model if last_agent_config else "unknown", + user_isid=user_isid, + telemetry=jt, + ) + average_ttft_ms = [] - with ( - jt.tracer.start_as_current_span("handler-stream") - if jt.telemetry_enabled() - else nullcontext() + with agent_telemetry.trace_agent_invocation( + "handler-stream", session_id=session_id, request_id=request_id ) as stream_span: logger.info("Beginning processing invoke stream") @@ -168,12 +196,69 @@ async def invoke_stream( total_tokens += i_response.token_usage.total_tokens collector.add_extra_data_items(i_response.extra_data) task_no += 1 + + # Record tool calls made by the agent during this task + tool_calls = getattr(task.agent, "last_tool_calls", None) + if isinstance(tool_calls, list) and tool_calls: + agent_telemetry.record_tool_calls(tool_calls) + + # Record reasoning tokens from this task (always report) + reasoning_tokens = getattr(task.agent, "last_reasoning_tokens", 0) + if not isinstance(reasoning_tokens, int): + reasoning_tokens = 0 + agent_telemetry.record_reasoning( + f"task:{task.name}:reasoning_tokens={reasoning_tokens}" + ) + + # Record each intermediate task invocation + agent_telemetry.record_internal_function_call( + f"task:{task.name}" + ) + yield IntermediateTaskResponse( task_no=task_no, task_name=task.name, response=i_response, ) except Exception as e: + if isinstance( + e, + ( + AgentUnavailableException, + LLMAuthenticationException, + LLMServiceException, + ), + ): + raise + if isinstance(e, ServiceResponseException): + inner = e.__cause__ or e + if isinstance(inner, openai.AuthenticationError): + raise LLMAuthenticationException( + status_code=401, + message=( + "LLM authentication failed for " + f"{self.name}:{self.version}: " + f"{str(inner)}" + ), + ) from e + if isinstance(inner, openai.RateLimitError): + raise LLMServiceException( + error_type="rate_limit", + message=( + f"LLM rate limit for " + f"{self.name}:{self.version}: " + f"{str(inner)}" + ), + status_code=429, + ) from e + raise LLMServiceException( + error_type="service_error", + message=( + f"LLM service error for " + f"{self.name}:{self.version}: " + f"{str(inner)}" + ), + ) from e raise AgentInvokeException( f"Error invoking {self.name}:{self.version} " f"for Session-id {session_id}, Request-id {request_id}, " @@ -213,15 +298,42 @@ async def invoke_stream( request_id=request_id, output_partial=content, ) - if stream_span: - stream_span.set_attribute("completion_tokens", completion_tokens) - stream_span.set_attribute("prompt_tokens", prompt_tokens) - stream_span.set_attribute("total_tokens", total_tokens) - average_ttft = sum(average_ttft_ms) / len(average_ttft_ms) if average_ttft_ms else 0 - stream_span.add_event( - "agent_time_to_first_token", - attributes={"first_token_time_ms": average_ttft}, - ) + + average_ttft = sum(average_ttft_ms) / len(average_ttft_ms) if average_ttft_ms else 0 + + # Record tool calls made by the final task's agent + final_tool_calls = getattr(self.tasks[-1].agent, "last_tool_calls", None) + if isinstance(final_tool_calls, list) and final_tool_calls: + agent_telemetry.record_tool_calls(final_tool_calls) + + # Record reasoning tokens from the final task (always report) + final_reasoning = getattr(self.tasks[-1].agent, "last_reasoning_tokens", 0) + if not isinstance(final_reasoning, int): + final_reasoning = 0 + agent_telemetry.record_reasoning( + f"task:{self.tasks[-1].name}:reasoning_tokens={final_reasoning}" + ) + + # Enrich span with agent metadata + agent_telemetry.enrich_span( + span=stream_span, + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + time_to_first_token_ms=average_ttft, + ) + + # Emit standardized structured log + agent_telemetry.emit_log( + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + ) + logger.info( f"{self.name}:{self.version} responded with {total_tokens} tokens. " f"Session-id {session_id}, Request-id {request_id}" @@ -268,10 +380,23 @@ async def invoke(self, inputs: dict[str, Any] | None = None) -> InvokeResponse: else: session_id = str(uuid.uuid4().hex) request_id = str(uuid.uuid4().hex) - with ( - jt.tracer.start_as_current_span("handler-invoke") - if jt.telemetry_enabled() - else nullcontext() + + # Determine agent info from the last task's agent for telemetry + last_agent_config = None + agents = self.config.get_agents() + if agents: + last_agent_config = agents[-1] + + user_isid = SequentialSkagents._extract_user_isid(inputs) + agent_telemetry = AgentTelemetryLogger( + agent_name=last_agent_config.name if last_agent_config else self.name, + model_name=last_agent_config.model if last_agent_config else "unknown", + user_isid=user_isid, + telemetry=jt, + ) + + with agent_telemetry.trace_agent_invocation( + "handler-invoke", session_id=session_id, request_id=request_id ) as invoke_span: average_ttft_ms = [] logger.info("Beginning processing invoke") @@ -288,23 +413,93 @@ async def invoke(self, inputs: dict[str, Any] | None = None) -> InvokeResponse: total_tokens += i_response.token_usage.total_tokens collector.add_extra_data_items(i_response.extra_data) task_no += 1 + + # Record tool calls made by the agent during this task + tool_calls = getattr(task.agent, "last_tool_calls", None) + if isinstance(tool_calls, list) and tool_calls: + agent_telemetry.record_tool_calls(tool_calls) + + # Record reasoning tokens from this task (always report) + reasoning_tokens = getattr(task.agent, "last_reasoning_tokens", 0) + if not isinstance(reasoning_tokens, int): + reasoning_tokens = 0 + agent_telemetry.record_reasoning( + f"task:{task.name}:reasoning_tokens={reasoning_tokens}" + ) + + # Record each task invocation + agent_telemetry.record_internal_function_call( + f"task:{task.name}" + ) except Exception as e: + if isinstance( + e, + ( + AgentUnavailableException, + LLMAuthenticationException, + LLMServiceException, + ), + ): + raise + if isinstance(e, ServiceResponseException): + inner = e.__cause__ or e + if isinstance(inner, openai.AuthenticationError): + raise LLMAuthenticationException( + status_code=401, + message=( + "LLM authentication failed for " + f"{self.name}:{self.version}: " + f"{str(inner)}" + ), + ) from e + if isinstance(inner, openai.RateLimitError): + raise LLMServiceException( + error_type="rate_limit", + message=( + f"LLM rate limit for " + f"{self.name}:{self.version}: " + f"{str(inner)}" + ), + status_code=429, + ) from e + raise LLMServiceException( + error_type="service_error", + message=( + f"LLM service error for " + f"{self.name}:{self.version}: " + f"{str(inner)}" + ), + ) from e raise AgentInvokeException( f"Error invoking {self.name}:{self.version} " f"for Session-id {session_id}, Request-id {request_id}, " f"Task description {task.description}. Error: {str(e)}" ) from e - if invoke_span: - invoke_span.set_attribute("completion_tokens", completion_tokens) - invoke_span.set_attribute("prompt_tokens", prompt_tokens) - invoke_span.set_attribute("total_tokens", total_tokens) - average_response_time = ( - sum(average_ttft_ms) / len(average_ttft_ms) if average_ttft_ms else 0 - ) - invoke_span.add_event( - "agent_response_time_ms", - attributes={"response_time_ms": average_response_time}, - ) + + average_response_time = ( + sum(average_ttft_ms) / len(average_ttft_ms) if average_ttft_ms else 0 + ) + + # Enrich span with agent metadata + agent_telemetry.enrich_span( + span=invoke_span, + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + time_to_first_token_ms=average_response_time, + ) + + # Emit standardized structured log + agent_telemetry.emit_log( + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + ) + logger.info( f"{self.name}:{self.version} responded with {total_tokens} tokens. " f"Session-id {session_id}, Request-id {request_id}" diff --git a/src/sk-agents/src/sk_agents/skagents/v1/sk_agent.py b/src/sk-agents/src/sk_agents/skagents/v1/sk_agent.py index 0229170cf..9b99a4c6c 100644 --- a/src/sk-agents/src/sk_agents/skagents/v1/sk_agent.py +++ b/src/sk-agents/src/sk_agents/skagents/v1/sk_agent.py @@ -4,11 +4,13 @@ from semantic_kernel.agents import ChatCompletionAgent from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ChatMessageContent +from semantic_kernel.contents.function_call_content import FunctionCallContent from semantic_kernel.contents.streaming_chat_message_content import ( StreamingChatMessageContent, ) from sk_agents.ska_types import ModelType +from sk_agents.skagents.v1.utils import get_reasoning_tokens_for_response class SKAgent: @@ -21,6 +23,8 @@ def __init__( self.model_name = model_name self.agent = agent self.model_attributes = model_attributes + self._last_tool_calls: list[str] = [] + self._last_reasoning_tokens: int = 0 def get_model_type(self) -> ModelType: return self.model_attributes["model_type"] @@ -28,12 +32,97 @@ def get_model_type(self) -> ModelType: def so_supported(self) -> bool: return self.model_attributes["so_supported"] + @property + def last_tool_calls(self) -> list[str]: + """Return tool call names from the most recent invocation.""" + return list(self._last_tool_calls) + + @property + def last_reasoning_tokens(self) -> int: + """Return the reasoning token count from the most recent invocation.""" + return self._last_reasoning_tokens + + @staticmethod + def _extract_tool_calls_from_messages( + messages: list[ChatMessageContent], + ) -> list[str]: + """Extract tool call names from intermediate messages. + + Semantic Kernel's ChatCompletionAgent handles tool calling internally + and provides intermediate messages (including FunctionCallContent) via + the on_intermediate_message callback. This method inspects those messages + to extract tool names. + """ + tool_calls: list[str] = [] + for message in messages: + for item in message.items: + if isinstance(item, FunctionCallContent): + full_name = ( + f"{item.plugin_name}.{item.function_name}" + if item.plugin_name + else item.function_name + ) + tool_calls.append(full_name) + return tool_calls + + @staticmethod + def _extract_reasoning_from_messages( + messages: list[ChatMessageContent], + ) -> int: + """Extract total reasoning tokens from intermediate messages.""" + total_reasoning = 0 + for message in messages: + total_reasoning += get_reasoning_tokens_for_response(message) + return total_reasoning + async def invoke_stream( self, history: ChatHistory ) -> AsyncIterable[StreamingChatMessageContent]: - async for result in self.agent.invoke_stream(messages=history): + self._last_tool_calls = [] + self._last_reasoning_tokens = 0 + intermediate_messages: list[ChatMessageContent] = [] + + async def _on_intermediate_message(message: ChatMessageContent) -> None: + intermediate_messages.append(message) + + async for result in self.agent.invoke_stream( + messages=history, + on_intermediate_message=_on_intermediate_message, + ): + # Accumulate reasoning tokens from streamed response chunks + self._last_reasoning_tokens += get_reasoning_tokens_for_response( + result.content + ) yield result.content + self._last_tool_calls = SKAgent._extract_tool_calls_from_messages( + intermediate_messages + ) + # Also check intermediate messages for any reasoning tokens + self._last_reasoning_tokens += SKAgent._extract_reasoning_from_messages( + intermediate_messages + ) async def invoke(self, history: ChatHistory) -> AsyncIterable[ChatMessageContent]: - async for result in self.agent.invoke(messages=history): + self._last_tool_calls = [] + self._last_reasoning_tokens = 0 + intermediate_messages: list[ChatMessageContent] = [] + response_messages: list[ChatMessageContent] = [] + + async def _on_intermediate_message(message: ChatMessageContent) -> None: + intermediate_messages.append(message) + + async for result in self.agent.invoke( + messages=history, + on_intermediate_message=_on_intermediate_message, + ): + response_messages.append(result.content) yield result.content + self._last_tool_calls = SKAgent._extract_tool_calls_from_messages( + intermediate_messages + ) + # Extract reasoning from both response messages and intermediate messages + self._last_reasoning_tokens = SKAgent._extract_reasoning_from_messages( + response_messages + ) + SKAgent._extract_reasoning_from_messages( + intermediate_messages + ) diff --git a/src/sk-agents/src/sk_agents/skagents/v1/utils.py b/src/sk-agents/src/sk_agents/skagents/v1/utils.py index 526f030fe..be6aaa889 100644 --- a/src/sk-agents/src/sk_agents/skagents/v1/utils.py +++ b/src/sk-agents/src/sk_agents/skagents/v1/utils.py @@ -90,3 +90,29 @@ def get_token_usage_for_google_response( content.inner_content.usage.output_tokens + content.inner_content.usage.input_tokens ), ) + + +def get_reasoning_tokens_for_response(content: ChatMessageContent) -> int: + """Extract reasoning/thinking token count from a response if available. + + Currently supports OpenAI models that report ``completion_tokens_details.reasoning_tokens``. + Returns 0 when the information is not present. + """ + try: + if ( + isinstance(content, ChatMessageContent) + and hasattr(content, "inner_content") + and content.inner_content is not None + and hasattr(content.inner_content, "usage") + and content.inner_content.usage is not None + and hasattr(content.inner_content.usage, "completion_tokens_details") + and content.inner_content.usage.completion_tokens_details is not None + and hasattr( + content.inner_content.usage.completion_tokens_details, "reasoning_tokens" + ) + ): + reasoning = content.inner_content.usage.completion_tokens_details.reasoning_tokens + return reasoning if reasoning and reasoning > 0 else 0 + except Exception: + pass + return 0 diff --git a/src/sk-agents/src/sk_agents/tealagents/v1alpha1/agent/handler.py b/src/sk-agents/src/sk_agents/tealagents/v1alpha1/agent/handler.py index 69e3aab73..cd71321c8 100755 --- a/src/sk-agents/src/sk_agents/tealagents/v1alpha1/agent/handler.py +++ b/src/sk-agents/src/sk_agents/tealagents/v1alpha1/agent/handler.py @@ -1,11 +1,14 @@ import asyncio import logging +import time import uuid from collections.abc import AsyncIterable from datetime import datetime from functools import reduce from typing import Literal +import httpx +import openai from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase from semantic_kernel.contents import ChatMessageContent, ImageContent, TextContent from semantic_kernel.contents.chat_history import ChatHistory @@ -13,13 +16,17 @@ from semantic_kernel.contents.function_result_content import FunctionResultContent from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent from semantic_kernel.contents.utils.author_role import AuthorRole +from semantic_kernel.exceptions import ServiceResponseException from semantic_kernel.kernel import Kernel -from ska_utils import AppConfig +from ska_utils import AgentTelemetryLogger, AppConfig from sk_agents.authorization.dummy_authorizer import DummyAuthorizer from sk_agents.exceptions import ( AgentInvokeException, + AgentUnavailableException, AuthenticationException, + LLMAuthenticationException, + LLMServiceException, PersistenceCreateError, PersistenceLoadError, ) @@ -46,6 +53,13 @@ logger = logging.getLogger(__name__) +def _extract_user_isid_from_message(inputs: UserMessage) -> str | None: + """Extract user ISID from UserMessage user_context if available.""" + if inputs.user_context: + return inputs.user_context.get("user.isid") or inputs.user_context.get("isid") + return None + + class TealAgentsV1Alpha1Handler(BaseHandler): def __init__( self, @@ -243,15 +257,91 @@ async def _invoke_function( kernel: Kernel, fc_content: FunctionCallContent ) -> FunctionResultContent: """Helper to execute a single tool function call.""" - function = kernel.get_function( - fc_content.plugin_name, - fc_content.function_name, - ) - kernel_argument = fc_content.to_kernel_arguments() - function_result = await function.invoke(kernel, kernel_argument) - return FunctionResultContent.from_function_call_content_and_result( - fc_content, function_result - ) + try: + function = kernel.get_function( + fc_content.plugin_name, + fc_content.function_name, + ) + kernel_argument = fc_content.to_kernel_arguments() + function_result = await function.invoke(kernel, kernel_argument) + return FunctionResultContent.from_function_call_content_and_result( + fc_content, function_result + ) + except httpx.ConnectError as e: + raise AgentUnavailableException( + agent_name=fc_content.plugin_name or "unknown", + message=( + f"Connection refused when calling tool " + f"{fc_content.function_name}: {str(e)}" + ), + ) from e + except httpx.HTTPStatusError as e: + if e.response.status_code in (502, 503, 504): + raise AgentUnavailableException( + agent_name=fc_content.plugin_name or "unknown", + message=f"Tool returned HTTP {e.response.status_code}: {str(e)}", + ) from e + raise + except httpx.TimeoutException as e: + raise AgentUnavailableException( + agent_name=fc_content.plugin_name or "unknown", + message=f"Timeout calling tool {fc_content.function_name}: {str(e)}", + ) from e + + @staticmethod + def _classify_llm_error(error: Exception) -> None: + """ + Classify LLM errors from ServiceResponseException or openai errors + and raise the appropriate specific exception. + """ + if isinstance(error, ServiceResponseException): + inner = error.__cause__ or error + else: + inner = error + + if isinstance(inner, openai.AuthenticationError): + raise LLMAuthenticationException( + status_code=401, + message=f"LLM authentication failed: {str(inner)}" + ) from error + if isinstance(inner, openai.PermissionDeniedError): + raise LLMAuthenticationException( + status_code=403, + message=f"LLM permission denied: {str(inner)}" + ) from error + if isinstance(inner, openai.RateLimitError): + raise LLMServiceException( + error_type="rate_limit", + message=f"LLM rate limit exceeded: {str(inner)}", + status_code=429, + ) from error + if isinstance(inner, openai.NotFoundError): + raise LLMServiceException( + error_type="model_not_found", + message=f"LLM model or resource not found: {str(inner)}", + status_code=404, + ) from error + if isinstance(inner, openai.APIStatusError): + raise LLMServiceException( + error_type="service_error", + message=f"LLM API error: {str(inner)}", + status_code=getattr(inner, "status_code", None), + ) from error + if isinstance(inner, openai.APIConnectionError): + raise LLMServiceException( + error_type="connection_error", + message=f"Cannot connect to LLM service: {str(inner)}", + ) from error + if isinstance(inner, openai.APITimeoutError): + raise LLMServiceException( + error_type="timeout", + message=f"LLM request timed out: {str(inner)}", + ) from error + if isinstance(error, ServiceResponseException): + raise LLMServiceException( + error_type="service_error", + message=f"LLM service error: {str(error)}", + ) from error @staticmethod def _augment_with_user_context(inputs: UserMessage, chat_history: ChatHistory) -> None: @@ -878,6 +968,7 @@ async def recursion_invoke( task_id: str, request_id: str, connection_manager=None, + agent_telemetry: AgentTelemetryLogger | None = None, ) -> TealAgentsResponse | HitlResponse: # Initial setup @@ -899,10 +990,29 @@ async def recursion_invoke( agent.agent.kernel, user_id, session_id, self.discovery_manager, connection_manager ) + # Initialize agent telemetry logger if not provided (top-level call) + from ska_utils import get_telemetry + + try: + jt = get_telemetry() + except ValueError: + jt = None + if agent_telemetry is None: + agent_config = self.config.get_agent() + agent_telemetry = AgentTelemetryLogger( + agent_name=agent_config.name, + model_name=agent_config.model, + user_isid=str(user_id) if user_id else None, + telemetry=jt, + ) + + agent_telemetry.record_invocation() + # Prepare metadata completion_tokens: int = 0 prompt_tokens: int = 0 total_tokens: int = 0 + start_time = time.time() try: # Manual tool calling implementation (existing logic) @@ -923,7 +1033,6 @@ async def recursion_invoke( arguments=arguments, ) for response_chunk in responses: - # response_list.extend(response_chunk) chat_history.add_message(response_chunk) response_list.append(response_chunk) @@ -938,18 +1047,32 @@ async def recursion_invoke( prompt_tokens += call_usage.prompt_tokens total_tokens += call_usage.total_tokens + # Check for reasoning/thinking in response metadata + reasoning_tokens = 0 + if hasattr(response, "inner_content") and response.inner_content: + inner = response.inner_content + # OpenAI reasoning tokens + if hasattr(inner, "usage") and inner.usage and hasattr( + inner.usage, "completion_tokens_details" + ): + details = inner.usage.completion_tokens_details + if details and hasattr(details, "reasoning_tokens"): + reasoning_tokens = details.reasoning_tokens or 0 + agent_telemetry.record_reasoning( + f"reasoning_tokens={reasoning_tokens}" + ) + # A response may have multiple items, e.g., multiple tool calls fc_in_response = [ item for item in response.items if isinstance(item, FunctionCallContent) ] if fc_in_response: - # chat_history.add_message(response) - # Add assistant's message to history function_calls.extend(fc_in_response) else: # If no function calls, it's a direct answer final_response = response + token_usage = TokenUsage( completion_tokens=completion_tokens, prompt_tokens=prompt_tokens, @@ -957,8 +1080,27 @@ async def recursion_invoke( ) # If tool calls were returned, execute them if function_calls: + # Record tool calls for telemetry + tool_names = [] + for fc in function_calls: + full_name = ( + f"{fc.plugin_name}.{fc.function_name}" + if fc.plugin_name + else fc.function_name + ) + tool_names.append(full_name) + agent_telemetry.record_tool_calls(tool_names) + await self._manage_function_calls(function_calls, chat_history, kernel) + # Record internal function calls (the actual kernel invocations) + for fc in function_calls: + agent_telemetry.record_internal_function_call( + f"{fc.plugin_name}.{fc.function_name}" + if fc.plugin_name + else fc.function_name + ) + # Make a recursive call to get the final response from the LLM recursive_response = await self.recursion_invoke( inputs=chat_history, @@ -966,6 +1108,7 @@ async def recursion_invoke( task_id=task_id, request_id=request_id, connection_manager=connection_manager, + agent_telemetry=agent_telemetry, ) return recursive_response @@ -984,6 +1127,21 @@ async def recursion_invoke( ) except Exception as e: + if isinstance( + e, + ( + AgentUnavailableException, + LLMAuthenticationException, + LLMServiceException, + ), + ): + raise + try: + self._classify_llm_error(e) + except (LLMAuthenticationException, LLMServiceException): + raise + except Exception: + pass logger.exception( f"Error invoking {self.name}:{self.version}" f"for Session ID {session_id}, Task ID {task_id}," @@ -996,6 +1154,25 @@ async def recursion_invoke( f" Request ID {request_id}, Error message: {str(e)}" ) from e + # Emit standardized structured log at the end of the invocation chain + elapsed_ms = (time.time() - start_time) * 1000 + agent_telemetry.enrich_span( + span=None, # span managed at invoke() level + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + time_to_first_token_ms=elapsed_ms, + ) + agent_telemetry.emit_log( + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + ) + # Persist and return response return await self.prepare_agent_response( agent_task, request_id, final_response, token_usage, extra_data_collector @@ -1008,6 +1185,7 @@ async def recursion_invoke_stream( task_id: str, request_id: str, connection_manager=None, + agent_telemetry: AgentTelemetryLogger | None = None, ) -> AsyncIterable[TealAgentsResponse | TealAgentsPartialResponse | HitlResponse]: chat_history = inputs agent_task = await self.state.load_by_request_id(request_id) @@ -1027,6 +1205,24 @@ async def recursion_invoke_stream( agent.agent.kernel, user_id, session_id, self.discovery_manager, connection_manager ) + # Initialize agent telemetry logger if not provided (top-level call) + from ska_utils import get_telemetry + + try: + jt = get_telemetry() + except ValueError: + jt = None + if agent_telemetry is None: + agent_config = self.config.get_agent() + agent_telemetry = AgentTelemetryLogger( + agent_name=agent_config.name, + model_name=agent_config.model, + user_isid=str(user_id) if user_id else None, + telemetry=jt, + ) + + agent_telemetry.record_invocation() + # Prepare metadata final_response = [] completion_tokens: int = 0 @@ -1063,6 +1259,20 @@ async def recursion_invoke_stream( prompt_tokens += call_usage.prompt_tokens total_tokens += call_usage.total_tokens + # Check for reasoning/thinking in response metadata + reasoning_tokens = 0 + if hasattr(response, "inner_content") and response.inner_content: + inner = response.inner_content + if hasattr(inner, "usage") and inner.usage and hasattr( + inner.usage, "completion_tokens_details" + ): + details = inner.usage.completion_tokens_details + if details and hasattr(details, "reasoning_tokens"): + reasoning_tokens = details.reasoning_tokens or 0 + agent_telemetry.record_reasoning( + f"reasoning_tokens={reasoning_tokens}" + ) + if response.content: try: # Attempt to parse as ExtraDataPartial @@ -1098,7 +1308,27 @@ async def recursion_invoke_stream( # If tool calls are present, execute them if function_calls: + # Record tool calls for telemetry + tool_names = [] + for fc in function_calls: + full_name = ( + f"{fc.plugin_name}.{fc.function_name}" + if fc.plugin_name + else fc.function_name + ) + tool_names.append(full_name) + agent_telemetry.record_tool_calls(tool_names) + await self._manage_function_calls(function_calls, chat_history, kernel) + + # Record internal function calls + for fc in function_calls: + agent_telemetry.record_internal_function_call( + f"{fc.plugin_name}.{fc.function_name}" + if fc.plugin_name + else fc.function_name + ) + # Make a recursive call to get the final streamed response async for final_response_chunk in self.recursion_invoke_stream( chat_history, @@ -1106,6 +1336,7 @@ async def recursion_invoke_stream( task_id, request_id, connection_manager=connection_manager, + agent_telemetry=agent_telemetry, ): yield final_response_chunk return @@ -1116,6 +1347,21 @@ async def recursion_invoke_stream( return except Exception as e: + if isinstance( + e, + ( + AgentUnavailableException, + LLMAuthenticationException, + LLMServiceException, + ), + ): + raise + try: + self._classify_llm_error(e) + except (LLMAuthenticationException, LLMServiceException): + raise + except Exception: + pass logger.exception( f"Error invoking stream for {self.name}:{self.version} " f"for Session ID {session_id}, Task ID {task_id}," @@ -1128,7 +1374,16 @@ async def recursion_invoke_stream( f"Request ID {request_id}, Error message: {str(e)}" ) from e - # # Persist and return response + # Emit standardized structured log + agent_telemetry.emit_log( + session_id=session_id, + request_id=request_id, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=total_tokens, + ) + + # Persist and return response yield await self.prepare_agent_response( agent_task, request_id, final_response, token_usage, extra_data_collector ) diff --git a/src/sk-agents/tests/test_routes.py b/src/sk-agents/tests/test_routes.py index e98b38761..36807421b 100755 --- a/src/sk-agents/tests/test_routes.py +++ b/src/sk-agents/tests/test_routes.py @@ -1200,9 +1200,9 @@ def mock_side_effect(*args, **kwargs): app.include_router(router, prefix="/api") client = TestClient(app) - # Should raise RuntimeError because the response has already started - # when the exception occurs in the SSE stream - with pytest.raises(RuntimeError, match="response already started"): + # Should raise Exception because the exception propagates from the SSE stream + # after the response has already started streaming + with pytest.raises(Exception, match="Stream error"): client.post( "/api/tealagents/v1alpha1/resume/request123/sse", json={"action": "approve"},