77import threading
88from collections .abc import Mapping
99from functools import reduce
10- from typing import Any , Literal , Optional , Union
10+ from typing import Annotated , Any , Literal , Optional , Union , get_args , get_origin
1111
1212import msgspec
1313import voluptuous
@@ -70,11 +70,40 @@ def validate_schema(schema, obj, msg_prefix):
7070 raise Exception (f"{ msg_prefix } \n { str (exc )} \n { pprint .pformat (obj )} " )
7171
7272
73- def UnionTypes (* types ):
74- """Use `functools.reduce` to simulate `Union[*allowed_types]` on older
75- Python versions.
76- """
77- return reduce (lambda a , b : Union [a , b ], types )
73+ class OptionallyKeyedBy :
74+ """Metadata class for optionally_keyed_by fields in msgspec schemas."""
75+
76+ def __init__ (self , * fields , wrapped_type ):
77+ self .fields = fields
78+ self .wrapped_type = wrapped_type
79+
80+ @classmethod
81+ def uses_keyed_by (cls , obj ) -> bool :
82+ if not isinstance (obj , dict ) or len (obj ) != 1 :
83+ return False
84+
85+ key = list (obj )[0 ]
86+ if not key .startswith ("by-" ):
87+ return False
88+
89+ return True
90+
91+ def validate (self , obj ) -> None :
92+ if not self .uses_keyed_by (obj ):
93+ # Not using keyed by, validate directly against wrapped type
94+ msgspec .convert (obj , self .wrapped_type )
95+ return
96+
97+ # First validate the outer keyed-by dict
98+ bykeys = UnionTypes (* [Literal [f"by-{ field } " ] for field in self .fields ])
99+ msgspec .convert (obj , dict [bykeys , dict ])
100+
101+ # Next validate each inner value. We call self.validate recursively to
102+ # support nested `by-*` keys.
103+ keyed_by_dict = list (obj .values ())[0 ]
104+ for value in keyed_by_dict .values ():
105+ self .validate (value )
106+
78107
79108
80109def optionally_keyed_by (* arguments , use_msgspec = False ):
@@ -86,13 +115,15 @@ def optionally_keyed_by(*arguments, use_msgspec=False):
86115 use_msgspec: If True, return msgspec type hints; if False, return voluptuous validator
87116 """
88117 if use_msgspec :
89- # msgspec implementation - return type hints
118+ # msgspec implementation - use Annotated[Any, OptionallyKeyedBy]
90119 _type = arguments [- 1 ]
91120 if _type is object :
92121 return object
93122 fields = arguments [:- 1 ]
94- bykeys = [Literal [f"by-{ field } " ] for field in fields ]
95- return Union [_type , dict [UnionTypes (* bykeys ), dict [str , Any ]]]
123+ wrapper = OptionallyKeyedBy (* fields , wrapped_type = _type )
124+ # Annotating Any allows msgspec to accept any value without validation.
125+ # The actual validation then happens in Schema.__post_init__
126+ return Annotated [Any , wrapper ]
96127 else :
97128 # voluptuous implementation - return validator function
98129 schema = arguments [- 1 ]
@@ -291,6 +322,13 @@ def __getitem__(self, item):
291322 return self .schema [item ] # type: ignore
292323
293324
325+ def UnionTypes (* types ):
326+ """Use `functools.reduce` to simulate `Union[*allowed_types]` on older
327+ Python versions.
328+ """
329+ return reduce (lambda a , b : Union [a , b ], types )
330+
331+
294332class Schema (
295333 msgspec .Struct ,
296334 kw_only = True ,
@@ -318,6 +356,31 @@ class MySchema(Schema, forbid_unknown_fields=False, kw_only=True):
318356 foo: str
319357 """
320358
359+ def __post_init__ (self ):
360+ if taskgraph .fast :
361+ return
362+
363+ # Validate fields that use optionally_keyed_by. We need to validate this
364+ # manually because msgspec doesn't support union types with multiple
365+ # dicts. Any fields that use `optionally_keyed_by("foo", dict)` would
366+ # otherwise raise an exception.
367+ for field_name , field_type in self .__class__ .__annotations__ .items ():
368+ origin = get_origin (field_type )
369+ args = get_args (field_type )
370+
371+ if (
372+ origin is not Annotated
373+ or len (args ) < 2
374+ or not isinstance (args [1 ], OptionallyKeyedBy )
375+ ):
376+ # Not using `optionally_keyed_by`
377+ continue
378+
379+ keyed_by = args [1 ]
380+ obj = getattr (self , field_name )
381+
382+ keyed_by .validate (obj )
383+
321384 @classmethod
322385 def validate (cls , data ):
323386 """Validate data against this schema."""
0 commit comments