From 4a11519d56dba708c0880e1f03906f70fc02ca96 Mon Sep 17 00:00:00 2001 From: William Conti Date: Thu, 5 Feb 2026 18:44:56 -0500 Subject: [PATCH] chore(contrib): decouple HTTP client events with class-based subscribers Refactor the events API to separate pure data events from tracing logic: - Add TracedEvent base class (pure data, no tracing imports) - Rewrite HttpClientRequestEvent to extend TracedEvent with HTTP-specific fields - Create trace_subscribers package with SpanTracingSubscriber base class - Add HttpClientTracingSubscriber for shared HTTP client tracing logic - Update httpx and requests integrations to use new event field names - Update appsec handlers to use new event name (http.client.request) This enables integrations to publish events without importing tracing, while tracing, AppSec, and future products subscribe independently. Co-Authored-By: Claude Opus 4.6 (1M context) --- ddtrace/_trace/trace_handlers.py | 58 +---- ddtrace/_trace/trace_subscribers/__init__.py | 5 + ddtrace/_trace/trace_subscribers/_base.py | 72 ++++++ .../_trace/trace_subscribers/http_client.py | 44 ++++ ddtrace/appsec/_handlers.py | 8 +- ddtrace/contrib/events/__init__.py | 0 ddtrace/contrib/events/http_client.py | 24 ++ ddtrace/contrib/internal/httpx/patch.py | 61 ++++-- ddtrace/contrib/internal/httpx/utils.py | 18 ++ .../contrib/internal/requests/connection.py | 91 +++----- ddtrace/internal/core/__init__.py | 5 + ddtrace/internal/core/events.py | 195 +++++++++++++++++ tests/internal/events/test_context_event.py | 205 ++++++++++++++++++ tests/internal/test_module.py | 3 + 14 files changed, 651 insertions(+), 138 deletions(-) create mode 100644 ddtrace/_trace/trace_subscribers/__init__.py create mode 100644 ddtrace/_trace/trace_subscribers/_base.py create mode 100644 ddtrace/_trace/trace_subscribers/http_client.py create mode 100644 ddtrace/contrib/events/__init__.py create mode 100644 ddtrace/contrib/events/http_client.py create mode 100644 ddtrace/contrib/internal/httpx/utils.py create mode 100644 ddtrace/internal/core/events.py create mode 100644 tests/internal/events/test_context_event.py diff --git a/ddtrace/_trace/trace_handlers.py b/ddtrace/_trace/trace_handlers.py index e5670560869..af529b76856 100644 --- a/ddtrace/_trace/trace_handlers.py +++ b/ddtrace/_trace/trace_handlers.py @@ -47,8 +47,6 @@ from ddtrace.ext.kafka import TOMBSTONE from ddtrace.ext.kafka import TOPIC from ddtrace.internal import core -from ddtrace.internal.compat import ensure_binary -from ddtrace.internal.compat import ensure_text from ddtrace.internal.compat import is_valid_ip from ddtrace.internal.compat import maybe_stringify from ddtrace.internal.constants import COMPONENT @@ -1337,58 +1335,10 @@ def _on_aiokafka_getmany_message( span.link_span(context) -def _on_httpx_request_start(ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> None: - span = _start_span(ctx, call_trace, **kwargs) - span._metrics[_SPAN_MEASURED_KEY] = 1 - - request = ctx.get_item("request") - - if trace_utils.distributed_tracing_enabled(config.httpx): - HTTPPropagator.inject(span.context, request.headers) - - -def httpx_url_to_str(url) -> str: - """ - Helper to convert the httpx.URL parts from bytes to a str - """ - scheme = url.raw_scheme - host = url.raw_host - port = url.port - raw_path = url.raw_path - url = scheme + b"://" + host - if port is not None: - url += b":" + ensure_binary(str(port)) - url += raw_path - - return ensure_text(url) - - -def _on_httpx_send_completed( - ctx: core.ExecutionContext, - exc_info: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]], -) -> None: - span = ctx.span - - request = ctx.get_item("request") - response = ctx.get_item("response") - - try: - trace_utils.set_http_meta( - span, - config.httpx, - method=request.method, - url=httpx_url_to_str(request.url), - target_host=request.url.host, - status_code=response.status_code if response else None, - query=request.url.query, - request_headers=request.headers, - response_headers=response.headers if response else None, - ) - finally: - _finish_span(ctx, exc_info) - - def listen(): + # Import subscriber package — triggers auto-registration via __init_subclass__ + import ddtrace._trace.trace_subscribers # noqa: F401 + core.on("wsgi.request.prepare", _on_request_prepare) core.on("wsgi.request.prepared", _on_request_prepared) core.on("wsgi.app.success", _on_app_success) @@ -1533,7 +1483,6 @@ def listen(): "aiokafka.getmany", ): core.on(f"context.started.{context_name}", _start_span) - core.on("context.started.httpx.request", _on_httpx_request_start) for name in ( "asgi.request", @@ -1571,7 +1520,6 @@ def listen(): # Special/extra handling before calling _finish_span core.on("context.ended.django.cache", _on_django_cache) - core.on("context.ended.httpx.request", _on_httpx_send_completed) listen() diff --git a/ddtrace/_trace/trace_subscribers/__init__.py b/ddtrace/_trace/trace_subscribers/__init__.py new file mode 100644 index 00000000000..17497a6f12d --- /dev/null +++ b/ddtrace/_trace/trace_subscribers/__init__.py @@ -0,0 +1,5 @@ +from ddtrace._trace.trace_subscribers._base import SpanTracingSubscriber # noqa: F401 +from ddtrace._trace.trace_subscribers._base import TracingSubscriber # noqa: F401 + +# Import subscriber modules to trigger auto-registration via __init_subclass__ +import ddtrace._trace.trace_subscribers.http_client # noqa: F401 diff --git a/ddtrace/_trace/trace_subscribers/_base.py b/ddtrace/_trace/trace_subscribers/_base.py new file mode 100644 index 00000000000..72f23fe1143 --- /dev/null +++ b/ddtrace/_trace/trace_subscribers/_base.py @@ -0,0 +1,72 @@ +from types import TracebackType +from typing import Optional +from typing import Tuple + +from ddtrace._trace.trace_handlers import _finish_span +from ddtrace._trace.trace_handlers import _start_span +from ddtrace.internal import core + + +class TracingSubscriber: + """Base class for tracing event subscribers. + + Subclasses that define ``event_name`` auto-register on + context.started.{event_name} and context.ended.{event_name}. + """ + + event_name: str + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if "event_name" not in cls.__dict__: + return + core.on( + f"context.started.{cls.event_name}", + cls._on_context_started, + name=f"{cls.__name__}.started", + ) + core.on( + f"context.ended.{cls.event_name}", + cls._on_context_ended, + name=f"{cls.__name__}.ended", + ) + + @classmethod + def _on_context_started(cls, ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> None: + cls.on_started(ctx, call_trace, **kwargs) + + @classmethod + def _on_context_ended( + cls, + ctx: core.ExecutionContext, + exc_info: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]], + ) -> None: + cls.on_ended(ctx, exc_info) + + @classmethod + def on_started(cls, ctx, call_trace=True, **kwargs): + pass + + @classmethod + def on_ended(cls, ctx, exc_info): + pass + + +class SpanTracingSubscriber(TracingSubscriber): + """Subscriber that creates a span on start and finishes it on end. + + Subclasses override on_started/on_ended for type-specific logic. + Span lifecycle is handled here — subclasses never call _start_span/_finish_span. + """ + + @classmethod + def _on_context_started(cls, ctx, call_trace=True, **kwargs): + _start_span(ctx, call_trace, **kwargs) + cls.on_started(ctx, call_trace, **kwargs) + + @classmethod + def _on_context_ended(cls, ctx, exc_info): + try: + cls.on_ended(ctx, exc_info) + finally: + _finish_span(ctx, exc_info) diff --git a/ddtrace/_trace/trace_subscribers/http_client.py b/ddtrace/_trace/trace_subscribers/http_client.py new file mode 100644 index 00000000000..220a584ce2f --- /dev/null +++ b/ddtrace/_trace/trace_subscribers/http_client.py @@ -0,0 +1,44 @@ +from ddtrace._trace.trace_subscribers._base import SpanTracingSubscriber +from ddtrace.contrib import trace_utils +from ddtrace.internal.logger import get_logger +from ddtrace.propagation.http import HTTPPropagator + + +log = get_logger(__name__) + + +class HttpClientTracingSubscriber(SpanTracingSubscriber): + """Shared tracing logic for ALL HTTP client integrations. + + httpx, requests, aiohttp, etc. all share this subscriber. + Adding a feature here applies to every HTTP client integration. + """ + + event_name = "http.client.request" + + @classmethod + def on_started(cls, ctx, call_trace=True, **kwargs): + span = ctx.span + integration_config = ctx.get_item("integration_config") + request_headers = ctx.get_item("request_headers") + if integration_config and request_headers is not None: + if trace_utils.distributed_tracing_enabled(integration_config): + HTTPPropagator.inject(span.context, request_headers) + + @classmethod + def on_ended(cls, ctx, exc_info): + span = ctx.span + try: + trace_utils.set_http_meta( + span, + ctx.get_item("integration_config"), + method=ctx.get_item("method"), + url=ctx.get_item("url"), + target_host=ctx.get_item("target_host"), + status_code=ctx.get_item("status_code"), + query=ctx.get_item("query"), + request_headers=ctx.get_item("request_headers"), + response_headers=ctx.get_item("response_headers"), + ) + except Exception: + log.debug("http.client: error adding tags", exc_info=True) diff --git a/ddtrace/appsec/_handlers.py b/ddtrace/appsec/_handlers.py index 4e315d40071..47d52dffdfe 100644 --- a/ddtrace/appsec/_handlers.py +++ b/ddtrace/appsec/_handlers.py @@ -8,7 +8,6 @@ from typing import Union from ddtrace._trace.span import Span -from ddtrace._trace.trace_handlers import httpx_url_to_str from ddtrace.appsec._asm_request_context import _call_waf from ddtrace.appsec._asm_request_context import _call_waf_first from ddtrace.appsec._asm_request_context import _get_asm_context @@ -25,6 +24,7 @@ from ddtrace.appsec._http_utils import parse_http_body from ddtrace.appsec._utils import Block_config from ddtrace.contrib import trace_utils +from ddtrace.contrib.internal.httpx.utils import httpx_url_to_str from ddtrace.contrib.internal.trace_utils_base import _get_request_header_user_agent from ddtrace.contrib.internal.trace_utils_base import _set_url_tag from ddtrace.ext import http @@ -615,7 +615,7 @@ def _on_httpx_request_ended(ctx: ExecutionContext, exc_info) -> None: return response = ctx.get_item("response") - if not response or (300 <= response.status_code < 400): + if response is None or (300 <= response.status_code < 400): return addresses = { @@ -659,8 +659,8 @@ def listen(): core.on("context.started.httpx.client._send_single_request", _on_httpx_client_send_single_request_started) core.on("context.ended.httpx.client._send_single_request", _on_httpx_client_send_single_request_ended) - core.on("context.started.httpx.request", _on_httpx_request_started) - core.on("context.ended.httpx.request", _on_httpx_request_ended) + core.on("context.started.http.client.request", _on_httpx_request_started) + core.on("context.ended.http.client.request", _on_httpx_request_ended) # disabling threats grpc listeners. # core.on("grpc.server.response.message", _on_grpc_server_response) diff --git a/ddtrace/contrib/events/__init__.py b/ddtrace/contrib/events/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/ddtrace/contrib/events/http_client.py b/ddtrace/contrib/events/http_client.py new file mode 100644 index 00000000000..40ba7d4f624 --- /dev/null +++ b/ddtrace/contrib/events/http_client.py @@ -0,0 +1,24 @@ +from typing import Optional + +from ddtrace.internal.core.events import TracedEvent +from ddtrace.internal.core.events import context_event +from ddtrace.internal.core.events import event_field + + +@context_event +class HttpClientRequestEvent(TracedEvent): + """HTTP client request event — pure data, no span manipulation. + + Integrations create this with library-specific data. + Tracing, AppSec, etc. subscribe via their own handlers. + """ + + event_name = "http.client.request" + + # HTTP-specific fields only — span metadata inherited from TracedEvent + url: str = event_field(in_context=True) + method: str = event_field(in_context=True) + target_host: Optional[str] = event_field(default=None, in_context=True) + query: Optional[object] = event_field(default=None, in_context=True) + request_headers: object = event_field(in_context=True) + request: object = event_field(in_context=True) diff --git a/ddtrace/contrib/internal/httpx/patch.py b/ddtrace/contrib/internal/httpx/patch.py index dd99bc532e9..2ab3610851f 100644 --- a/ddtrace/contrib/internal/httpx/patch.py +++ b/ddtrace/contrib/internal/httpx/patch.py @@ -10,6 +10,7 @@ from ddtrace import config from ddtrace.constants import SPAN_KIND +from ddtrace.contrib.events.http_client import HttpClientRequestEvent from ddtrace.contrib.internal.trace_utils import ext_service from ddtrace.ext import SpanKind from ddtrace.ext import SpanTypes @@ -24,6 +25,8 @@ from ddtrace.internal.utils.version import parse_version from ddtrace.internal.utils.wrappers import unwrap as _u +from .utils import httpx_url_to_str + HTTPX_VERSION = parse_version(httpx.__version__) HTTP_REQUEST_TAGS = {COMPONENT: config.httpx.integration_name, SPAN_KIND: SpanKind.CLIENT} @@ -94,16 +97,23 @@ async def _wrapped_async_send_single_request( async def _wrapped_async_send( wrapped: BoundFunctionWrapper, instance: httpx.AsyncClient, args: Tuple[httpx.Request], kwargs: Dict[str, Any] ): - req = get_argument_value(args, kwargs, 0, "request") - - with core.context_with_data( - "httpx.request", - call_trace=True, - span_name=schematize_url_operation("http.request", protocol="http", direction=SpanDirection.OUTBOUND), - span_type=SpanTypes.HTTP, - service=_get_service_name(req), - tags=HTTP_REQUEST_TAGS, - request=req, + req: httpx.Request = get_argument_value(args, kwargs, 0, "request") # type: ignore + + url_str = httpx_url_to_str(req.url) + with core.context_with_event( + HttpClientRequestEvent( + span_name=schematize_url_operation("http.request", protocol="http", direction=SpanDirection.OUTBOUND), + span_type=SpanTypes.HTTP, + service=_get_service_name(req), + tags=HTTP_REQUEST_TAGS, + integration_config=config.httpx, + url=url_str, + method=req.method, + target_host=req.url.host, + query=req.url.query, + request_headers=req.headers, + request=req, + ), ) as ctx: resp = None try: @@ -111,21 +121,30 @@ async def _wrapped_async_send( return resp finally: ctx.set_item("response", resp) + ctx.set_item("status_code", resp.status_code if resp else None) + ctx.set_item("response_headers", resp.headers if resp else None) def _wrapped_sync_send( wrapped: BoundFunctionWrapper, instance: httpx.AsyncClient, args: Tuple[httpx.Request], kwargs: Dict[str, Any] ): - req = get_argument_value(args, kwargs, 0, "request") - - with core.context_with_data( - "httpx.request", - call_trace=True, - span_name=schematize_url_operation("http.request", protocol="http", direction=SpanDirection.OUTBOUND), - span_type=SpanTypes.HTTP, - service=_get_service_name(req), - tags=HTTP_REQUEST_TAGS, - request=req, + req: httpx.Request = get_argument_value(args, kwargs, 0, "request") # type: ignore + + url_str = httpx_url_to_str(req.url) + with core.context_with_event( + HttpClientRequestEvent( + span_name=schematize_url_operation("http.request", protocol="http", direction=SpanDirection.OUTBOUND), + span_type=SpanTypes.HTTP, + service=_get_service_name(req), + tags=HTTP_REQUEST_TAGS, + integration_config=config.httpx, + url=url_str, + method=req.method, + target_host=req.url.host, + query=req.url.query, + request_headers=req.headers, + request=req, + ), ) as ctx: resp = None try: @@ -133,6 +152,8 @@ def _wrapped_sync_send( return resp finally: ctx.set_item("response", resp) + ctx.set_item("status_code", resp.status_code if resp else None) + ctx.set_item("response_headers", resp.headers if resp else None) def patch() -> None: diff --git a/ddtrace/contrib/internal/httpx/utils.py b/ddtrace/contrib/internal/httpx/utils.py new file mode 100644 index 00000000000..6c6d690c51c --- /dev/null +++ b/ddtrace/contrib/internal/httpx/utils.py @@ -0,0 +1,18 @@ +from ddtrace.internal.compat import ensure_binary +from ddtrace.internal.compat import ensure_text + + +def httpx_url_to_str(url) -> str: + """ + Helper to convert the httpx.URL parts from bytes to a str + """ + scheme = url.raw_scheme + host = url.raw_host + port = url.port + raw_path = url.raw_path + url = scheme + b"://" + host + if port is not None: + url += b":" + ensure_binary(str(port)) + url += raw_path + + return ensure_text(url) diff --git a/ddtrace/contrib/internal/requests/connection.py b/ddtrace/contrib/internal/requests/connection.py index 56f8ba1a247..78f0f4f2973 100644 --- a/ddtrace/contrib/internal/requests/connection.py +++ b/ddtrace/contrib/internal/requests/connection.py @@ -1,18 +1,17 @@ -from typing import Any # noqa:F401 -from typing import Dict # noqa:F401 from typing import Optional # noqa:F401 from urllib import parse import requests from ddtrace import config -from ddtrace._trace.pin import Pin -from ddtrace.constants import _SPAN_MEASURED_KEY +from ddtrace import tracer from ddtrace.constants import SPAN_KIND -from ddtrace.contrib import trace_utils +from ddtrace.contrib.events.http_client import HttpClientRequestEvent from ddtrace.contrib.internal.trace_utils import _sanitized_url +from ddtrace.contrib.internal.trace_utils import ext_service from ddtrace.ext import SpanKind from ddtrace.ext import SpanTypes +from ddtrace.internal import core from ddtrace.internal.constants import COMPONENT from ddtrace.internal.constants import USER_AGENT_HEADER from ddtrace.internal.logger import get_logger @@ -21,8 +20,6 @@ from ddtrace.internal.schema.span_attribute_schema import SpanDirection from ddtrace.internal.settings.asm import config as asm_config from ddtrace.internal.utils import get_argument_value -from ddtrace.propagation.http import HTTPPropagator -from ddtrace.trace import tracer log = get_logger(__name__) @@ -69,6 +66,13 @@ def _extract_query_string(uri): return uri[start:end] +def _get_service_name(request, hostname) -> Optional[str]: + if config.requests["split_by_domain"] and hostname: + return hostname + + return ext_service(None, config.requests) + + def _wrap_send(func, instance, args, kwargs): """Trace the `Session.send` instance method""" # skip if tracing is not enabled @@ -86,59 +90,28 @@ def _wrap_send(func, instance, args, kwargs): hostname, path = _extract_hostname_and_path(url) host_without_port = hostname.split(":")[0] if hostname is not None else None - cfg: Dict[str, Any] = {} - pin = Pin.get_from(instance) - if pin: - cfg = pin._config - - service = None - if cfg["split_by_domain"] and hostname: - service = hostname - if service is None: - service = cfg.get("service", None) - if service is None: - service = cfg.get("service_name", None) - if service is None: - service = trace_utils.ext_service(None, config.requests) - - operation_name = schematize_url_operation("requests.request", protocol="http", direction=SpanDirection.OUTBOUND) - with tracer.trace(operation_name, service=service, resource=f"{method} {path}", span_type=SpanTypes.HTTP) as span: - span._set_tag_str(COMPONENT, config.requests.integration_name) - - # set span.kind to the type of operation being performed - span._set_tag_str(SPAN_KIND, SpanKind.CLIENT) - - # PERF: avoid setting via Span.set_tag - span.set_metric(_SPAN_MEASURED_KEY, 1) - - # propagate distributed tracing headers - if cfg.get("distributed_tracing"): - HTTPPropagator.inject(span.context, request.headers) - - response = response_headers = None + with core.context_with_event( + HttpClientRequestEvent( + span_name=schematize_url_operation("requests.request", protocol="http", direction=SpanDirection.OUTBOUND), + span_type=SpanTypes.HTTP, + service=_get_service_name(request, hostname), + resource=f"{method} {path}", + tags={COMPONENT: config.requests.integration_name, SPAN_KIND: SpanKind.CLIENT}, + integration_config=config.requests, + url=url, + method=method, + target_host=host_without_port, + query=_extract_query_string(url), + request_headers=request.headers, + request=request, + ), + ) as ctx: + response = None try: response = func(*args, **kwargs) return response finally: - try: - status = None - if response is not None: - status = response.status_code - # Storing response headers in the span. - # Note that response.headers is not a dict, but an iterable - # requests custom structure, that we convert to a dict - response_headers = dict(getattr(response, "headers", {})) - - trace_utils.set_http_meta( - span, - config.requests, - request_headers=request.headers, - response_headers=response_headers, - method=method, - url=request.url, - target_host=host_without_port, - status_code=status, - query=_extract_query_string(url), - ) - except Exception: - log.debug("requests: error adding tags", exc_info=True) + ctx.set_item("response", response) + # Note: requests.Response is falsy for status >= 400, so use `is not None` + ctx.set_item("status_code", response.status_code if response is not None else None) + ctx.set_item("response_headers", dict(getattr(response, "headers", {})) if response is not None else None) diff --git a/ddtrace/internal/core/__init__.py b/ddtrace/internal/core/__init__.py index ee9454d5a7e..41225ad81d0 100644 --- a/ddtrace/internal/core/__init__.py +++ b/ddtrace/internal/core/__init__.py @@ -118,6 +118,7 @@ def _on_jsonify_context_started_flask(ctx): if typing.TYPE_CHECKING: from ddtrace._trace.span import Span # noqa:F401 + from ddtrace.internal.core.events import ContextEvent # noqa:F401 import contextvars @@ -287,6 +288,10 @@ def context_with_data(identifier, parent=None, **kwargs): return _CONTEXT_CLASS(identifier, parent=(parent or _CURRENT_CONTEXT.get()), **kwargs) +def context_with_event(event: "ContextEvent", parent=None): + return _CONTEXT_CLASS(event.event_name, parent=(parent or _CURRENT_CONTEXT.get()), **event.create_event_context()) + + def add_suppress_exception(exc_type: type) -> None: _CURRENT_CONTEXT.get()._suppress_exceptions.append(exc_type) diff --git a/ddtrace/internal/core/events.py b/ddtrace/internal/core/events.py new file mode 100644 index 00000000000..be5d3705dba --- /dev/null +++ b/ddtrace/internal/core/events.py @@ -0,0 +1,195 @@ +""" +Events API — an abstraction above the Core API for type-safe event dispatching. + +Events enforce the arguments that can be passed when dispatching, and allow better +correlation between dispatch sites and handlers. + +Example using ``context_with_event``:: + + @context_event + class MyEvent(ContextEvent): + event_name = "my.event" + + foo: str # automatically converted to event_field by @context_event + bar: str = event_field(in_context=True) # stored in ExecutionContext + + @classmethod + def _on_context_started(cls, ctx, call_trace=True, **kwargs): + print(ctx.get_item("bar")) + + with core.context_with_event(MyEvent(foo="hello", bar="world")): + pass + +Example using ``TracedEvent`` (pure data, tracing handled by subscribers):: + + @context_event + class HttpClientRequestEvent(TracedEvent): + event_name = "http.client.request" + url: str = event_field(in_context=True) + method: str = event_field(in_context=True) +""" + +from dataclasses import MISSING +from dataclasses import dataclass +from dataclasses import field +from dataclasses import fields +import sys +from types import TracebackType +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple + +from ddtrace.internal import core + + +def context_event(cls: Any) -> Any: + """Decorator that converts a class into a dataclass with automatic event_field() defaults. + + By automatically applying event_field() to fields without defaults, this decorator allows child classes + to define required fields naturally while maintaining compatibility with parent class fields that have defaults. + + Example: + @context_event + class MyEvent(ContextEvent): + url: str # Automatically gets event_field() applied + """ + + annotations = cls.__dict__.get("__annotations__", {}) + + # For each annotated field without a default, set it to event_field() + for name, _ in annotations.items(): + # Only apply `event_field()` for fields that don't define a default value + # in the class body. + if name not in cls.__dict__: + setattr(cls, name, event_field()) + + return dataclass(cls) + + +def event_field( + default: Any = MISSING, + default_factory: Any = MISSING, + in_context: bool = False, +) -> Any: + """Creates a dataclass field with special handling for event context data and Python version compatibility. + Event fields ensure retro compatibility as python 3.9 does not support kw_only which is + required to allow a child class to have attributes without value. + + Args: + default: Default value for the field + default_factory: Factory function to generate default values + in_context: Whether this field should be included in the ExecutionContext dict + """ + if default is not MISSING and default_factory is not MISSING: + raise ValueError("Cannot specify both default and default_factory") + + kwargs: Dict[str, Any] = {"repr": in_context} + if default is not MISSING: + kwargs["default"] = default + elif default_factory is not MISSING: + kwargs["default_factory"] = default_factory + + # Python 3.9: Give fields without defaults a None default to work around + # field ordering constraints with inheritance + if sys.version_info < (3, 10): + if default is MISSING and default_factory is MISSING: + kwargs["default"] = None + else: + kwargs["kw_only"] = True + + return field(**kwargs) + + +@dataclass +class ContextEvent: + """Base class for context-based events used with ``core.context_with_event()``. + + Subclasses that define ``event_name`` auto-register their ``_on_context_started`` + and ``_on_context_ended`` hooks via ``__init_subclass__``. + + Should be decorated with ``@context_event``. + """ + + event_name: str = field(init=False, repr=False) + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + if "event_name" not in cls.__dict__: + return + + core.on( + f"context.started.{cls.event_name}", + cls._registered_context_started, + name=f"{cls.__name__}_started", + ) + core.on( + f"context.ended.{cls.event_name}", + cls._registered_context_ended, + name=f"{cls.__name__}_ended", + ) + + def create_event_context(self): + """Convert this event instance into a dict for creating an ExecutionContext. + Only event_field marked with in_context=True will be included. + """ + return {f.name: getattr(self, f.name) for f in fields(self) if f.repr} + + @classmethod + def _on_context_started(cls, ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> None: + """Override in subclasses to handle context start.""" + pass + + @classmethod + def _on_context_ended( + cls, + ctx: core.ExecutionContext, + exc_info: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]], + ) -> None: + """Override in subclasses to handle context end.""" + pass + + @classmethod + def _registered_context_started(cls, ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> None: + # _on_context_started will be called in order from parent class to children classes. + for base_cls in reversed(cls.__mro__[:-1]): + if issubclass(base_cls, ContextEvent) and "_on_context_started" in base_cls.__dict__: + base_cls._on_context_started(ctx, call_trace, **kwargs) + + @classmethod + def _registered_context_ended( + cls, + ctx: core.ExecutionContext, + exc_info: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]], + ) -> None: + # _on_context_ended will be called in order from parent class to children classes. + for base_cls in reversed(cls.__mro__[:-1]): + if issubclass(base_cls, ContextEvent) and "_on_context_ended" in base_cls.__dict__: + base_cls._on_context_ended(ctx, exc_info) + + +@dataclass +class TracedEvent(ContextEvent): + """Base for events that carry span metadata — pure data, no tracing logic. + + Domain-specific subclasses add their own fields. Tracing logic lives in + subscriber classes (ddtrace/_trace/trace_subscribers/). + + TracedEvent subclasses do NOT auto-register ContextEvent handlers — subscribers + handle all start/end logic independently. + """ + + def __init_subclass__(cls, **kwargs): + # Skip ContextEvent's auto-registration of _registered_context_started/ended. + # TracedEvent subclasses rely on TracingSubscriber for all event handling. + pass + + span_name: str = event_field(in_context=True) + span_type: str = event_field(in_context=True) + call_trace: bool = event_field(default=True, in_context=True) + service: Optional[str] = event_field(default=None, in_context=True) + resource: Optional[str] = event_field(default=None, in_context=True) + tags: Dict[str, str] = event_field(default_factory=dict, in_context=True) + measured: bool = event_field(default=True, in_context=True) + integration_config: object = event_field(in_context=True) diff --git a/tests/internal/events/test_context_event.py b/tests/internal/events/test_context_event.py new file mode 100644 index 00000000000..c665b3d202d --- /dev/null +++ b/tests/internal/events/test_context_event.py @@ -0,0 +1,205 @@ +import sys +from types import TracebackType +from typing import Optional +from typing import Tuple + +import pytest + +from ddtrace.internal import core +from ddtrace.internal.core import event_hub +from ddtrace.internal.core.events import ContextEvent +from ddtrace.internal.core.events import context_event +from ddtrace.internal.core.events import event_field + + +@pytest.fixture(autouse=True) +def reset_event_hub(): + """Reset event hub after each test to prevent listener leakage between tests.""" + yield + event_hub.reset() + + +def test_basic_context_event(): + """Test that ContextEvent triggers _on_context_started and _on_context_ended hooks.""" + called = [] + + @context_event + class TestContextEvent(ContextEvent): + event_name = "test.event" + + @classmethod + def _on_context_started(cls, ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> None: + called.append("started") + + @classmethod + def _on_context_ended( + cls, + ctx: core.ExecutionContext, + exc_info: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]], + ) -> None: + called.append("ended") + + with core.context_with_event(TestContextEvent()): + pass + + assert called == ["started", "ended"] + + +def test_context_event_double_dispatch(): + """Test that dispatching the same context event twice calls hooks twice but registers only once.""" + called = [] + + @context_event + class TestContextEvent(ContextEvent): + event_name = "test.event" + + @classmethod + def _on_context_started(cls, ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> None: + called.append("started") + + @classmethod + def _on_context_ended( + cls, + ctx: core.ExecutionContext, + exc_info: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]], + ) -> None: + called.append("ended") + + with core.context_with_event(TestContextEvent()): + pass + with core.context_with_event(TestContextEvent()): + pass + + assert called == ["started", "ended", "started", "ended"] + + # Ensure that we register test.event only once + from ddtrace.internal.core.event_hub import _listeners + + assert len(_listeners[f"context.started.{TestContextEvent.event_name}"].values()) == 1 + assert len(_listeners[f"context.ended.{TestContextEvent.event_name}"].values()) == 1 + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+") +def test_context_event_enforce_kwargs_error(): + """Test that missing required fields raise TypeError. + On Python 3.9, we create a default value to every attributes because kw_only + is not available in dataclass field. Therefore we skip the test + """ + called = [] + + @context_event + class TestContextEvent(ContextEvent): + event_name = "test.event" + foo: str + bar: int + + @classmethod + def _on_context_started(cls, ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> None: + called.append("started") + + @classmethod + def _on_context_ended( + cls, + ctx: core.ExecutionContext, + exc_info: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]], + ) -> None: + called.append("ended") + + with pytest.raises(TypeError): + with core.context_with_event(TestContextEvent(foo="toto")): + pass + + assert called == [] + + +def test_context_event_event_field(): + """Test that missing required fields raise TypeError.""" + called = [] + + @context_event + class TestContextEvent(ContextEvent): + event_name = "test.event" + foo: str = event_field(in_context=True) + not_in_context: int + with_default: str = event_field(default="test", in_context=True) + + @classmethod + def _on_context_started(cls, ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> None: + called.append("started") + called.append(ctx.get_item("foo")) + called.append(ctx.get_item("with_default")) + + assert ctx.get_item("not_in_context") is None + + with core.context_with_event(TestContextEvent(foo="toto", not_in_context=0)): + pass + + assert called == ["started", "toto", "test"] + + +def test_content_event_inheriting_context_event(): + """Test that child ContextEvent inherits and extends parent's hooks.""" + called = [] + + @context_event + class TestContextEvent(ContextEvent): + event_name = "test.event" + foo: str = event_field(in_context=True) + + @classmethod + def _on_context_started(cls, ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> None: + called.append("base_started") + + @classmethod + def _on_context_ended( + cls, + ctx: core.ExecutionContext, + exc_info: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]], + ) -> None: + called.append("base_ended") + + @context_event + class ChildTestContextEvent(TestContextEvent): + event_name = "test.child.event" + + @classmethod + def _on_context_started(cls, ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> None: + called.append("child_started") + + @classmethod + def _on_context_ended( + cls, + ctx: core.ExecutionContext, + exc_info: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]], + ) -> None: + called.append("child_ended") + called.append(ctx.get_item("foo")) + + with core.context_with_event(ChildTestContextEvent(foo="toto")): + pass + + assert called == ["base_started", "child_started", "base_ended", "child_ended", "toto"] + + +def test_context_event_with_exception(): + """Test that exception info is properly passed to _on_context_ended.""" + called = [] + + @context_event + class TestContextEvent(ContextEvent): + event_name = "test.event" + + @classmethod + def _on_context_ended( + cls, + ctx: core.ExecutionContext, + exc_info: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]], + ) -> None: + called.append(exc_info) + + with pytest.raises(ValueError): + with core.context_with_event(TestContextEvent()): + raise ValueError("test error") + + assert called[0][0] == ValueError + assert str(called[0][1]) == "test error" diff --git a/tests/internal/test_module.py b/tests/internal/test_module.py index 9798efd2749..ad85d6a3341 100644 --- a/tests/internal/test_module.py +++ b/tests/internal/test_module.py @@ -555,6 +555,9 @@ def test_public_modules_in_ddtrace_contrib(): if "internal" in relative_dir.parts: # ignore modules in ddtrace/contrib/internal continue + if "events" in relative_dir.parts: + # ignore modules in ddtrace/contrib/events + continue for file_name in file_names: if file_name.endswith(".py"):