Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions kpops/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,19 @@ def build_execution_graph(
/,
reverse: bool = False,
) -> Awaitable[None]:
async def run_parallel_tasks(
coroutines: list[Coroutine[Any, Any, None]],
async def run_layer_parallel(
components: list[PipelineComponent],
) -> None:
tasks: list[asyncio.Task[None]] = []
for coro in coroutines:
tasks.append(asyncio.create_task(coro))
for component in components:
tasks.append(asyncio.create_task(runner(component)))
await asyncio.gather(*tasks)

async def run_graph_tasks(pending_tasks: list[Awaitable[None]]) -> None:
for pending_task in pending_tasks:
await pending_task
async def run_graph_layers(
pending_layers: list[list[PipelineComponent]],
) -> None:
for layer_components in pending_layers:
await run_layer_parallel(layer_components)

graph: nx.DiGraph[str] = self._graph.copy()

Expand All @@ -130,15 +132,15 @@ async def run_graph_tasks(pending_tasks: list[Awaitable[None]]) -> None:

layers_graph: list[list[str]] = list(nx.bfs_layers(graph, root_node))

sorted_tasks: list[Awaitable[None]] = []
sorted_layers: list[list[PipelineComponent]] = []
for layer in layers_graph[1:]:
if parallel_tasks := self.__get_parallel_tasks_from(layer, runner):
sorted_tasks.append(run_parallel_tasks(parallel_tasks))
if parallel_components := self.__get_parallel_components_from(layer):
sorted_layers.append(parallel_components)

if reverse:
sorted_tasks.reverse()
sorted_layers.reverse()

return run_graph_tasks(sorted_tasks)
return run_graph_layers(sorted_layers)

def __getitem__(self, component_id: str) -> PipelineComponent:
try:
Expand Down Expand Up @@ -173,18 +175,16 @@ def __add_input(self, topic_id: str, target: str) -> None:
self._graph.add_node(topic_id)
self._graph.add_edge(topic_id, target)

def __get_parallel_tasks_from(
self,
layer: list[str],
runner: Callable[[PipelineComponent], Coroutine[Any, Any, None]],
) -> list[Coroutine[Any, Any, None]]:
def gen_parallel_tasks():
def __get_parallel_components_from(
self, layer: list[str]
) -> list[PipelineComponent]:
def gen_parallel_components() -> Iterator[PipelineComponent]:
for node_in_layer in layer:
# check if component, skip topics
if (component := self._component_index.get(node_in_layer)) is not None:
yield runner(component)
yield component

return list(gen_parallel_tasks())
return list(gen_parallel_components())

def __validate_graph(self) -> None:
if not nx.is_directed_acyclic_graph(self._graph):
Expand Down
Loading