Skip to content

Commit 7387014

Browse files
committed
fix: support optionally_keyed_by with underlying dict
1 parent 1b24dc3 commit 7387014

File tree

2 files changed

+90
-13
lines changed

2 files changed

+90
-13
lines changed

src/taskgraph/util/schema.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import threading
88
from collections.abc import Mapping
99
from functools import reduce
10-
from typing import Any, Literal, Optional, Union
10+
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin
1111

1212
import msgspec
1313
import voluptuous
@@ -70,11 +70,40 @@ def validate_schema(schema, obj, msg_prefix):
7070
raise Exception(f"{msg_prefix}\n{str(exc)}\n{pprint.pformat(obj)}")
7171

7272

73-
def UnionTypes(*types):
74-
"""Use `functools.reduce` to simulate `Union[*allowed_types]` on older
75-
Python versions.
76-
"""
77-
return reduce(lambda a, b: Union[a, b], types)
73+
class OptionallyKeyedBy:
    """Metadata class for optionally_keyed_by fields in msgspec schemas.

    An instance records which ``by-<field>`` keys a value may be keyed by and
    the type every leaf value has to convert to. It is intended to be carried
    as ``Annotated`` metadata so the value can be checked separately from
    msgspec's own struct conversion.

    Attributes:
        fields: Names that may appear as ``by-<name>`` keys.
        wrapped_type: Type passed to ``msgspec.convert`` for non-keyed values.
    """

    def __init__(self, *fields, wrapped_type):
        self.fields = fields
        self.wrapped_type = wrapped_type

    @classmethod
    def uses_keyed_by(cls, obj) -> bool:
        """Return True if ``obj`` looks like a keyed-by dict.

        That is: a dict with exactly one entry whose key is a string starting
        with ``by-``. Anything else (including dicts with non-string keys)
        returns False.
        """
        if not isinstance(obj, dict) or len(obj) != 1:
            return False

        key = next(iter(obj))
        # Keys are not guaranteed to be strings (e.g. ``{1: "b"}``). Calling
        # ``startswith`` on such a key would raise AttributeError instead of
        # letting the value fail validation against the wrapped type.
        return isinstance(key, str) and key.startswith("by-")

    def validate(self, obj) -> None:
        """Validate ``obj`` as either a plain value or a keyed-by dict.

        Args:
            obj: The value to check.

        Raises:
            Whatever ``msgspec.convert`` raises when ``obj`` (or a nested
            value) does not match; the accompanying tests expect
            ``msgspec.ValidationError``.
        """
        # With no keyed-by fields configured there is no legal ``by-*`` key,
        # so every value is checked against the wrapped type. This also
        # avoids reducing an empty sequence of key types below.
        if not self.fields or not self.uses_keyed_by(obj):
            # Not using keyed by, validate directly against wrapped type
            msgspec.convert(obj, self.wrapped_type)
            return

        # First validate the outer keyed-by dict
        bykeys = UnionTypes(*[Literal[f"by-{field}"] for field in self.fields])
        msgspec.convert(obj, dict[bykeys, dict])

        # Next validate each inner value. We call self.validate recursively to
        # support nested `by-*` keys.
        (keyed_by_dict,) = obj.values()
        for value in keyed_by_dict.values():
            self.validate(value)
106+
78107

79108

80109
def optionally_keyed_by(*arguments, use_msgspec=False):
@@ -86,13 +115,15 @@ def optionally_keyed_by(*arguments, use_msgspec=False):
86115
use_msgspec: If True, return msgspec type hints; if False, return voluptuous validator
87116
"""
88117
if use_msgspec:
89-
# msgspec implementation - return type hints
118+
# msgspec implementation - use Annotated[Any, OptionallyKeyedBy]
90119
_type = arguments[-1]
91120
if _type is object:
92121
return object
93122
fields = arguments[:-1]
94-
bykeys = [Literal[f"by-{field}"] for field in fields]
95-
return Union[_type, dict[UnionTypes(*bykeys), dict[str, Any]]]
123+
wrapper = OptionallyKeyedBy(*fields, wrapped_type=_type)
124+
# Annotating Any allows msgspec to accept any value without validation.
125+
# The actual validation then happens in Schema.__post_init__
126+
return Annotated[Any, wrapper]
96127
else:
97128
# voluptuous implementation - return validator function
98129
schema = arguments[-1]
@@ -291,6 +322,13 @@ def __getitem__(self, item):
291322
return self.schema[item] # type: ignore
292323

293324

325+
def UnionTypes(*types):
    """Combine ``types`` into a single ``Union``.

    The types are folded pairwise with `functools.reduce`, which stands in
    for ``Union[*allowed_types]`` on Python versions lacking that syntax.
    """

    def _join(left, right):
        return Union[left, right]

    return reduce(_join, types)
330+
331+
294332
class Schema(
295333
msgspec.Struct,
296334
kw_only=True,
@@ -318,6 +356,31 @@ class MySchema(Schema, forbid_unknown_fields=False, kw_only=True):
318356
foo: str
319357
"""
320358

359+
def __post_init__(self):
    """Validate fields declared with ``optionally_keyed_by``.

    Does nothing when ``taskgraph.fast`` is truthy. Otherwise every field
    annotated as ``Annotated[..., OptionallyKeyedBy(...)]`` has its current
    value passed to that ``OptionallyKeyedBy.validate``; any exception it
    raises propagates to the caller.
    """
    if taskgraph.fast:
        return

    # Validate fields that use optionally_keyed_by. We need to validate this
    # manually because msgspec doesn't support union types with multiple
    # dicts. Any fields that use `optionally_keyed_by("foo", dict)` would
    # otherwise raise an exception.
    #
    # A class's `__annotations__` only lists the fields declared on that
    # class itself, so walk the MRO to also cover fields inherited from
    # parent schemas. The most-derived declaration of a name wins.
    seen = set()
    for klass in type(self).__mro__:
        annotations = getattr(klass, "__annotations__", None) or {}
        for field_name, field_type in annotations.items():
            if field_name in seen:
                continue
            seen.add(field_name)

            origin = get_origin(field_type)
            args = get_args(field_type)

            if (
                origin is not Annotated
                or len(args) < 2
                or not isinstance(args[1], OptionallyKeyedBy)
            ):
                # Not using `optionally_keyed_by`
                continue

            keyed_by = args[1]
            obj = getattr(self, field_name)

            keyed_by.validate(obj)
383+
321384
@classmethod
322385
def validate(cls, data):
323386
"""Validate data against this schema."""

test/test_util_schema.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,24 +290,36 @@ class TestSchema(Schema):
290290
TestSchema.validate({"field": "baz"})
291291
TestSchema.validate({"field": {"by-foo": {"a": "b", "c": "d"}}})
292292

293-
# Inner dict values are Any, so mixed types are accepted
294-
TestSchema.validate({"field": {"by-foo": {"a": 1, "c": "d"}}})
293+
with pytest.raises(msgspec.ValidationError):
294+
TestSchema.validate({"field": 1})
295+
296+
with pytest.raises(msgspec.ValidationError):
297+
TestSchema.validate({"field": {"by-bar": "a"}})
298+
299+
with pytest.raises(msgspec.ValidationError):
300+
TestSchema.validate({"field": {"by-bar": {1: "b"}}})
295301

296302
with pytest.raises(msgspec.ValidationError):
297303
TestSchema.validate({"field": {"by-bar": {"a": "b"}}})
298304

305+
with pytest.raises(msgspec.ValidationError):
306+
TestSchema.validate({"field": {"by-foo": {"a": 1, "c": "d"}}})
307+
299308

300-
def test_optionally_keyed_by_mulitple_keys():
309+
def test_optionally_keyed_by_multiple_keys():
    """Each declared ``by-*`` key is accepted, including when nested."""

    class TestSchema(Schema):
        field: optionally_keyed_by("foo", "bar", str, use_msgspec=True)  # type: ignore

    accepted = (
        {"field": {"by-foo": {"a": "b"}}},
        {"field": {"by-bar": {"x": "y"}}},
        {"field": {"by-foo": {"a": {"by-bar": {"x": "y"}}}}},
    )
    for payload in accepted:
        TestSchema.validate(payload)

    rejected = (
        # Test invalid keyed-by field
        {"field": {"by-unknown": {"a": "b"}}},
        # Nested keyed-by whose leaf value is not a str
        {"field": {"by-foo": {"a": {"by-bar": {"x": 1}}}}},
    )
    for payload in rejected:
        with pytest.raises(msgspec.ValidationError):
            TestSchema.validate(payload)
311323

312324
def test_optionally_keyed_by_object_passthrough():
313325
"""When the type argument is `object`, optionally_keyed_by returns object directly."""
@@ -320,14 +332,16 @@ def test_optionally_keyed_by_object_passthrough():
320332
assert msgspec.convert({"arbitrary": "dict"}, typ) == {"arbitrary": "dict"}
321333

322334

323-
@pytest.mark.xfail
324335
def test_optionally_keyed_by_dict():
    """A dict-typed field works both keyed by ``by-foo`` and as a plain dict."""

    class TestSchema(Schema):
        field: optionally_keyed_by("foo", dict[str, str], use_msgspec=True)  # type: ignore

    accepted = (
        {"field": {"by-foo": {"a": {"x": "y"}}}},
        {"field": {"a": "b"}},
    )
    for payload in accepted:
        TestSchema.validate(payload)

    rejected = (
        # Plain dict with a non-str value
        {"field": {"a": 1}},
        # Keyed-by dict whose inner dict has a non-str value
        {"field": {"by-foo": {"a": {"x": 1}}}},
    )
    for payload in rejected:
        with pytest.raises(msgspec.ValidationError):
            TestSchema.validate(payload)
333347

0 commit comments

Comments
 (0)