Skip to content

Commit c50fda7

Browse files
committed
Refine reentrance check to handle import/export forwarding
1 parent 7479890 commit c50fda7

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,11 @@ class Store:
189189
def __init__(self):
190190
self.pending = []
191191

192-
def invoke(self, f: FuncInst, caller, on_start, on_resolve) -> Call:
193-
return f(caller, on_start, on_resolve)
192+
def invoke(self, f: FuncInst, caller: Optional[Supertask], on_start, on_resolve) -> Call:
193+
host_caller = Supertask()
194+
host_caller.inst = None
195+
host_caller.supertask = caller
196+
return f(host_caller, on_start, on_resolve)
194197

195198
def tick(self):
196199
random.shuffle(self.pending)
@@ -205,7 +208,7 @@ def tick(self):
205208
OnResolve = Callable[[Optional[list[any]]], None]
206209

207210
class Supertask:
208-
inst: ComponentInstance
211+
inst: Optional[ComponentInstance]
209212
supertask: Optional[Supertask]
210213

211214
class Call:
@@ -252,20 +255,38 @@ class CanonicalOptions(LiftLowerOptions):
252255

253256
class ComponentInstance:
254257
store: Store
258+
parent: Optional[ComponentInstance]
255259
table: Table
256260
may_leave: bool
257261
backpressure: int
258262
exclusive: bool
259263
num_waiting_to_enter: int
260264

261-
def __init__(self, store):
265+
def __init__(self, store, parent = None):
266+
assert(parent is None or parent.store is store)
262267
self.store = store
268+
self.parent = parent
263269
self.table = Table()
264270
self.may_leave = True
265271
self.backpressure = 0
266272
self.exclusive = False
267273
self.num_waiting_to_enter = 0
268274

275+
def ancestors(inst: Optional[ComponentInstance]) -> set[ComponentInstance]:
276+
s = set()
277+
while inst is not None:
278+
s.add(inst)
279+
inst = inst.parent
280+
return s
281+
282+
def call_is_recursive(caller: Supertask, callee_inst: ComponentInstance):
283+
callee_insts = { callee_inst } & (ancestors(callee_inst) - ancestors(caller.inst))
284+
while caller is not None:
285+
if callee_insts & ancestors(caller.inst):
286+
return True
287+
caller = caller.supertask
288+
return False
289+
269290
#### Table State
270291

271292
class Table:
@@ -534,7 +555,7 @@ class State(Enum):
534555
opts: CanonicalOptions
535556
inst: ComponentInstance
536557
ft: FuncType
537-
supertask: Optional[Task]
558+
supertask: Task
538559
on_resolve: OnResolve
539560
num_borrows: int
540561
threads: list[Thread]
@@ -560,12 +581,6 @@ def thread_stop(self, thread):
560581
trap_if(self.state != Task.State.RESOLVED)
561582
assert(self.num_borrows == 0)
562583

563-
def trap_if_on_the_stack(self, inst):
564-
c = self.supertask
565-
while c is not None:
566-
trap_if(c.inst is inst)
567-
c = c.supertask
568-
569584
def needs_exclusive(self):
570585
return not self.opts.async_ or self.opts.callback
571586

@@ -1984,8 +1999,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
19841999
### `canon lift`
19852000

19862001
def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve) -> Call:
2002+
trap_if(call_is_recursive(caller, inst))
19872003
task = Task(opts, inst, ft, caller, on_resolve)
1988-
task.trap_if_on_the_stack(inst)
19892004
def thread_func(thread):
19902005
if not task.enter(thread):
19912006
return
@@ -2167,7 +2182,7 @@ def canon_resource_drop(rt, thread, i):
21672182
callee = partial(canon_lift, callee_opts, rt.impl, ft, rt.dtor)
21682183
[] = canon_lower(caller_opts, ft, callee, thread, [h.rep])
21692184
else:
2170-
thread.task.trap_if_on_the_stack(rt.impl)
2185+
trap_if(call_is_recursive(thread.task, rt.impl))
21712186
else:
21722187
h.borrow_scope.num_borrows -= 1
21732188
return []

0 commit comments

Comments
 (0)