diff --git a/scripts/gen_schema.py b/scripts/gen_schema.py index 9a181c9..9a78b26 100644 --- a/scripts/gen_schema.py +++ b/scripts/gen_schema.py @@ -116,6 +116,30 @@ ("ToolCallUpdate", "status", "ToolCallStatus", True), ) + +@dataclass(frozen=True) +class FieldValidatorInjection: + """A generated field validator that should be appended to one schema class.""" + + class_name: str + field_name: str + method_name: str + argument_name: str + return_type: str + comment_lines: tuple[str, ...] + body_lines: tuple[str, ...] + + def render(self) -> str: + lines = [ + f'@field_validator("{self.field_name}", mode="before")', + "@classmethod", + f"def {self.method_name}(cls, {self.argument_name}: Any) -> {self.return_type}:", + ] + lines.extend(f" # {line}" for line in self.comment_lines) + lines.extend(f" {line}" for line in self.body_lines) + return "\n".join(lines) + + DEFAULT_VALUE_OVERRIDES: tuple[tuple[str, str, str], ...] = ( ("AgentCapabilities", "mcp_capabilities", "McpCapabilities()"), ("AgentCapabilities", "session_capabilities", "SessionCapabilities()"), @@ -138,6 +162,31 @@ ), ) +# Classes that need a field_validator injected after generation. +CLASS_VALIDATOR_INJECTIONS: tuple[FieldValidatorInjection, ...] = ( + FieldValidatorInjection( + class_name="InitializeRequest", + field_name="protocol_version", + method_name="_coerce_protocol_version", + argument_name="value", + return_type="int", + comment_lines=( + 'Some clients (e.g. Zed) send a date string like "2024-11-05" instead', + "of an integer. The Rust SDK treats legacy strings as version 0; this", + "SDK maps unparsable values to 1 so the connection is not rejected.", + "See: https://github.com/agentclientprotocol/rust-sdk/blob/main/crates/agent-client-protocol-schema/src/version.rs", + ), + body_lines=( + "if isinstance(value, int):", + " return value", + "try:", + " return int(value)", + "except (TypeError, ValueError):", + " return 1", + ), + ), +) + @dataclass(frozen=True) class _ProcessingStep: @@ -200,6 +249,7 @@ def postprocess_generated_schema(output_path: Path) -> list[str]: _ProcessingStep("apply default overrides", _apply_default_overrides), _ProcessingStep("attach description comments", _add_description_comments), _ProcessingStep("ensure custom BaseModel", _ensure_custom_base_model), + _ProcessingStep("inject field validators", _inject_field_validators), ) for step in processing_steps: @@ -356,6 +406,47 @@ def __getattr__(self, item: str) -> Any: return "\n".join(lines) + "\n" +def _ensure_pydantic_import(content: str, name: str) -> str: + """Add *name* to the ``from pydantic import ...`` line if not already present.""" + lines = content.splitlines() + for idx, line in enumerate(lines): + if not line.startswith("from pydantic import "): + continue + imports = [part.strip() for part in line[len("from pydantic import ") :].split(",")] + if name not in imports: + imports.append(name) + lines[idx] = "from pydantic import " + ", ".join(imports) + return "\n".join(lines) + "\n" + return content + + +def _inject_field_validators(content: str) -> str: + """Inject field_validator methods into classes listed in CLASS_VALIDATOR_INJECTIONS.""" + for injection in CLASS_VALIDATOR_INJECTIONS: + content = _ensure_pydantic_import(content, "field_validator") + + class_pattern = re.compile( + rf"(class {injection.class_name}\(BaseModel\):)(.*?)(?=\nclass |\Z)", + re.DOTALL, + ) + + def _append_validator( + match: re.Match[str], + _injection: FieldValidatorInjection = injection, + ) -> str: + header, block = match.group(1), match.group(2) + indented = "\n" + textwrap.indent(_injection.render(), " ") + return header + block + indented + "\n" + + content, count = class_pattern.subn(_append_validator, content, count=1) + if count == 0: + print( + f"Warning: class {injection.class_name} not found for validator injection", + file=sys.stderr, + ) + return content + + def _apply_field_overrides(content: str) -> str: for class_name, field_name, new_type, optional in FIELD_TYPE_OVERRIDES: if optional: diff --git a/src/acp/schema.py b/src/acp/schema.py index e942245..614c7ed 100644 --- a/src/acp/schema.py +++ b/src/acp/schema.py @@ -6,7 +6,7 @@ from enum import Enum from typing import Annotated, Any, Dict, List, Literal, Optional, Union -from pydantic import AnyUrl, BaseModel as _BaseModel, Field, RootModel, ConfigDict +from pydantic import AnyUrl, BaseModel as _BaseModel, Field, RootModel, ConfigDict, field_validator PermissionOptionKind = Literal["allow_once", "allow_always", "reject_once", "reject_always"] PlanEntryPriority = Literal["high", "medium", "low"] @@ -3922,6 +3922,20 @@ class InitializeRequest(BaseModel): ), ] + @field_validator("protocol_version", mode="before") + @classmethod + def _coerce_protocol_version(cls, value: Any) -> int: + # Some clients (e.g. Zed) send a date string like "2024-11-05" instead + # of an integer. The Rust SDK treats legacy strings as version 0; this + # SDK maps unparsable values to 1 so the connection is not rejected. + # See: https://github.com/agentclientprotocol/rust-sdk/blob/main/crates/agent-client-protocol-schema/src/version.rs + if isinstance(value, int): + return value + try: + return int(value) + except (TypeError, ValueError): + return 1 + class LoadSessionRequest(BaseModel): # The _meta property is reserved by ACP to allow clients and agents to attach additional diff --git a/tests/test_rpc.py b/tests/test_rpc.py index 19bdc6c..32c7464 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -236,8 +236,8 @@ async def test_invalid_params_results_in_error_response(connect, server): # Only start agent-side (server) so we can inject raw request from client socket connect(connect_agent=True, connect_client=False) - # Send initialize with wrong param type (protocolVersion should be int) - req = {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "oops"}} + # Send initialize without the required protocolVersion field. + req = {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {}} server.client_writer.write((json.dumps(req) + "\n").encode()) await server.client_writer.drain() @@ -249,6 +249,22 @@ async def test_invalid_params_results_in_error_response(connect, server): assert resp["error"]["code"] == -32602 # invalid params +@pytest.mark.asyncio +async def test_initialize_accepts_legacy_string_protocol_version(connect, server): + # Only start agent-side (server) so we can inject raw request from client socket. + connect(connect_agent=True, connect_client=False) + + req = {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "2024-11-05"}} + server.client_writer.write((json.dumps(req) + "\n").encode()) + await server.client_writer.drain() + + line = await asyncio.wait_for(server.client_reader.readline(), timeout=1) + resp = json.loads(line) + assert resp["id"] == 1 + assert "error" not in resp + assert resp["result"]["protocolVersion"] == 1 + + @pytest.mark.asyncio async def test_method_not_found_results_in_error_response(connect, server): connect(connect_agent=True, connect_client=False)