@@ -189,8 +189,11 @@ class Store:
189189 def __init__ (self ):
190190 self .pending = []
191191
192- def invoke (self , f : FuncInst , caller , on_start , on_resolve ) -> Call :
193- return f (caller , on_start , on_resolve )
192+ def invoke (self , f : FuncInst , caller : Optional [Supertask ], on_start , on_resolve ) -> Call :
193+ host_caller = Supertask ()
194+ host_caller .inst = None
195+ host_caller .supertask = caller
196+ return f (host_caller , on_start , on_resolve )
194197
195198 def tick (self ):
196199 random .shuffle (self .pending )
@@ -205,7 +208,7 @@ def tick(self):
205208OnResolve = Callable [[Optional [list [any ]]], None ]
206209
207210class Supertask :
208- inst : ComponentInstance
211+ inst : Optional [ ComponentInstance ]
209212 supertask : Optional [Supertask ]
210213
211214class Call :
@@ -252,20 +255,38 @@ class CanonicalOptions(LiftLowerOptions):
252255
253256class ComponentInstance :
254257 store : Store
258+ parent : Optional [ComponentInstance ]
255259 table : Table
256260 may_leave : bool
257261 backpressure : int
258262 exclusive : bool
259263 num_waiting_to_enter : int
260264
261- def __init__ (self , store ):
265+ def __init__ (self , store , parent = None ):
266+ assert (parent is None or parent .store is store )
262267 self .store = store
268+ self .parent = parent
263269 self .table = Table ()
264270 self .may_leave = True
265271 self .backpressure = 0
266272 self .exclusive = False
267273 self .num_waiting_to_enter = 0
268274
275+ def ancestors (inst : Optional [ComponentInstance ]) -> set [ComponentInstance ]:
276+ s = set ()
277+ while inst is not None :
278+ s .add (inst )
279+ inst = inst .parent
280+ return s
281+
282+ def call_is_recursive (caller : Supertask , callee_inst : ComponentInstance ):
283+ callee_insts = { callee_inst } & (ancestors (callee_inst ) - ancestors (caller .inst ))
284+ while caller is not None :
285+ if callee_insts & ancestors (caller .inst ):
286+ return True
287+ caller = caller .supertask
288+ return False
289+
269290#### Table State
270291
271292class Table :
@@ -534,7 +555,7 @@ class State(Enum):
534555 opts : CanonicalOptions
535556 inst : ComponentInstance
536557 ft : FuncType
537- supertask : Optional [ Task ]
558+ supertask : Task
538559 on_resolve : OnResolve
539560 num_borrows : int
540561 threads : list [Thread ]
@@ -560,12 +581,6 @@ def thread_stop(self, thread):
560581 trap_if (self .state != Task .State .RESOLVED )
561582 assert (self .num_borrows == 0 )
562583
563- def trap_if_on_the_stack (self , inst ):
564- c = self .supertask
565- while c is not None :
566- trap_if (c .inst is inst )
567- c = c .supertask
568-
569584 def needs_exclusive (self ):
570585 return not self .opts .async_ or self .opts .callback
571586
@@ -1984,8 +1999,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
19841999### `canon lift`
19852000
19862001def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve ) -> Call :
2002+ trap_if (call_is_recursive (caller , inst ))
19872003 task = Task (opts , inst , ft , caller , on_resolve )
1988- task .trap_if_on_the_stack (inst )
19892004 def thread_func (thread ):
19902005 if not task .enter (thread ):
19912006 return
@@ -2167,7 +2182,7 @@ def canon_resource_drop(rt, thread, i):
21672182 callee = partial (canon_lift , callee_opts , rt .impl , ft , rt .dtor )
21682183 [] = canon_lower (caller_opts , ft , callee , thread , [h .rep ])
21692184 else :
2170- thread .task . trap_if_on_the_stack ( rt .impl )
2185+ trap_if ( call_is_recursive ( thread .task , rt .impl ) )
21712186 else :
21722187 h .borrow_scope .num_borrows -= 1
21732188 return []
0 commit comments