Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from collections.abc import Iterable

from vllm.v1.request import Request, RequestStatus


class Scheduler:
def _mamba_block_aligned_split(
self,
Expand Down Expand Up @@ -26,3 +31,68 @@ def _mamba_block_aligned_split(
):
num_new_tokens = last_cache_position - num_computed_tokens
return num_new_tokens

def _update_requests_with_invalid_blocks(
self,
requests: Iterable[Request],
invalid_block_ids: set[int],
evict_blocks: bool = True,
) -> tuple[set[str], int, set[int]]:
# Roll back the computed-token count of every request that holds a block
# listed in invalid_block_ids.
#
# Returns (affected_req_ids, total_affected_tokens, blocks_to_evict).
# blocks_to_evict is only filled when evict_blocks is true.
#
# NOTE(review): indentation is not preserved in this view, so the nesting
# described in the comments below is inferred from statement order; confirm
# against the real file.
affected_req_ids: set[str] = set()
total_affected_tokens = 0
blocks_to_evict: set[int] = set()
# Invalid block ids that some request in this call has already recorded.
# The set is shared across all requests of the call.
marked_invalid_block_ids: set[int] = set()
for request in requests:
is_affected = False
marked_invalid_block = False
req_id = request.request_id
req_block_id_groups = self.kv_cache_manager.get_block_ids(req_id)
# A request with no block-id groups is skipped entirely.
if not req_block_id_groups:
continue

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method _update_requests_with_invalid_blocks has complex nested control flow with multiple continue statements. Consider refactoring into smaller helper methods or using early returns to improve readability.

# vLLM v0.18's recovery path assumes one KV group. Hybrid
# Qwen3-Next has multiple groups, and UCM's HMA dispatch always
# includes the full-attention group as group 0, so use it as the
# recovery anchor.
req_block_ids = req_block_id_groups[0]
# Token count that bounds the scan: num_computed_tokens while the request is
# in WAITING_FOR_REMOTE_KVS, num_cached_tokens otherwise.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
req_num_computed_tokens = request.num_computed_tokens
else:
req_num_computed_tokens = request.num_cached_tokens

# Ceiling division: number of blocks needed to hold that many tokens.
req_num_computed_blocks = (
req_num_computed_tokens + self.block_size - 1
) // self.block_size
# zip() stops at the shorter input, so at most req_num_computed_blocks
# entries of group 0 are inspected.
for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids):
if block_id not in invalid_block_ids:
continue

is_affected = True

# Block already recorded in marked_invalid_block_ids: the request still
# counts as affected, but nothing else is done for this block.
if block_id in marked_invalid_block_ids:
continue

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable marked_invalid_block is set to True after processing the first invalid block, but request.num_computed_tokens is updated for each invalid block encountered. This means the final value will be based on the last invalid block, not the first. Is this the intended behavior?

marked_invalid_block_ids.add(block_id)

# Only the first newly recorded block of a request reaches the truncation
# below; later ones are recorded in the set above and then skipped here.
if marked_invalid_block:
continue

marked_invalid_block = True
# Truncate to the token offset at which block idx starts.
request.num_computed_tokens = idx * self.block_size
num_affected_tokens = (
req_num_computed_tokens - request.num_computed_tokens

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block handles the case where a request is affected but no invalid block was directly processed. The calculation request.num_computed_tokens - request.num_cached_tokens assumes num_computed_tokens >= num_cached_tokens. If this invariant doesn't hold, the total affected tokens could be negative. Consider adding an assertion.

)
total_affected_tokens += num_affected_tokens
request.num_external_computed_tokens -= num_affected_tokens
# Eviction set receives the block at idx and every later block of group 0.
if evict_blocks:
blocks_to_evict.update(req_block_ids[idx:])

# Reached after the block scan for a request that held at least one invalid
# block (nesting inferred, see the note at the top).
if is_affected:
if not marked_invalid_block:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Warning: marked_invalid_block is set inside the for-loop (line 79) but used here after the loop. The variable only reflects the last block iteration's state, not the overall marking status. If multiple blocks are invalid and the last one is already in marked_invalid_block_ids, marked_invalid_block stays False even though earlier blocks were marked. Consider tracking marked_invalid_block at the request level outside the inner loop.

# No block was newly recorded by this request: fall back to
# num_cached_tokens and add the difference to the running total.
total_affected_tokens += (
request.num_computed_tokens - request.num_cached_tokens
)
request.num_computed_tokens = request.num_cached_tokens

affected_req_ids.add(request.request_id)

return affected_req_ids, total_affected_tokens, blocks_to_evict
5 changes: 5 additions & 0 deletions ucm/integration/vllm/patch/v0180/vllm/pc_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def patch_core_sched_scheduler(mod):
"_mamba_block_aligned_split",
scheduler.Scheduler._mamba_block_aligned_split,
)
patch_or_inject(
mod.Scheduler,
"_update_requests_with_invalid_blocks",
scheduler.Scheduler._update_requests_with_invalid_blocks,
)


@when_imported("vllm.v1.worker.gpu_model_runner")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from functools import wraps

from ucm.integration.vllm.patch.utils import patch_or_inject, when_imported
from ucm.logger import init_logger

logger = init_logger(__name__)


def _patch_empty_gdn_save(mod) -> None:
original_save = getattr(mod, "maybe_save_kv_layer_to_connector", None)
if original_save is None or getattr(
original_save, "_ucm_skip_empty_gdn_save", False
):
return

@wraps(original_save)
def skip_empty_gdn_save(layer_name, kv_cache_layer):
if layer_name == "" and kv_cache_layer == []:
return
return original_save(layer_name, kv_cache_layer)

skip_empty_gdn_save._ucm_skip_empty_gdn_save = True
mod.maybe_save_kv_layer_to_connector = skip_empty_gdn_save


def _wrap_gdn_forward_core(mod) -> None:
# Replace _forward_core on mod.Qwen3NextGatedDeltaNet (and on
# mod.AscendQwen3Next_GatedDeltaNet when mod defines it) with a wrapper that
# calls wait_for_kv_layer_from_connector before the original and
# maybe_save_kv_layer_to_connector after it.
#
# Logs a warning and returns without patching when the target class or its
# _forward_core is missing; returns silently when already patched.
#
# NOTE(review): indentation is not preserved in this view, so the nesting
# described in the comments below is inferred from statement order; confirm
# against the real file.
target_cls = getattr(mod, "Qwen3NextGatedDeltaNet", None)
if target_cls is None:
logger.warning("Skip Qwen3Next GDN UCM patch: target class not found.")
return

original_forward_core = getattr(target_cls, "_forward_core", None)
if original_forward_core is None:
logger.warning("Skip Qwen3Next GDN UCM patch: _forward_core not found.")
return
# The marker is set on the wrapper near the end of this function, so a second
# call finds it here and does not wrap twice.
if getattr(original_forward_core, "_ucm_gdn_layerwise_patched", False):
return

@wraps(original_forward_core)
def ucm_forward_core(self, mixed_qkv, b, a, core_attn_out):
# These imports run when the wrapper is called, not when the patch is
# installed. Presumably this avoids importing vllm / vllm_ascend at
# patch-registration time; TODO confirm.
from vllm.forward_context import get_forward_context

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Suggestion: Importing modules (vllm.forward_context, vllm.v1.attention.backends.gdn_attn, etc.) inside the wrapped function body causes these imports to execute on every forward call. This is inefficient and could raise ImportError at runtime if modules aren't available. Consider moving imports to the top of the function scope or using lazy import patterns.

from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm_ascend.attention.utils import (
maybe_save_kv_layer_to_connector,
wait_for_kv_layer_from_connector,
)
from vllm_ascend.utils import vllm_version_is

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code accesses self.kv_cache[kv_cache_index] without checking if self.kv_cache exists or if the index is valid. This could raise AttributeError or IndexError at runtime.


forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
should_save = False
# Only dict-shaped attn_metadata is inspected; this layer's entry is looked
# up under self.prefix.
if isinstance(attn_metadata, dict):
layer_attn_metadata = attn_metadata.get(self.prefix)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imports inside the wrapper function make runtime import failures difficult to debug. Consider moving imports to module level or adding explicit error handling.

# dict.get() yields None for a missing key and isinstance(None, ...) is False,
# so a missing entry simply leaves should_save False.
if isinstance(layer_attn_metadata, GDNAttentionMetadata):
wait_for_kv_layer_from_connector(self.prefix)
should_save = True

result = original_forward_core(self, mixed_qkv, b, a, core_attn_out)

# The save happens only when the wait above was performed for this call.
if should_save:
kv_cache_index = (
forward_context.virtual_engine if vllm_version_is("0.18.0") else 0

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After layer_attn_metadata = attn_metadata.get(self.prefix), the code should check if layer_attn_metadata is not None before the isinstance check.

)
self_kv_cache = self.kv_cache[kv_cache_index]
# The selected cache entry is handed over as a list.
maybe_save_kv_layer_to_connector(self.prefix, list(self_kv_cache))
return result

ucm_forward_core._ucm_gdn_layerwise_patched = True
patch_or_inject(target_cls, "_forward_core", ucm_forward_core)

# NOTE(review): the wrapper closes over target_cls's _forward_core. If
# ascend_cls defines its own _forward_core, it is replaced here by a wrapper
# around target_cls's version; confirm that this is intended.
ascend_cls = getattr(mod, "AscendQwen3Next_GatedDeltaNet", None)
if ascend_cls is not None:
patch_or_inject(ascend_cls, "_forward_core", ucm_forward_core)


@when_imported("vllm_ascend.patch.worker.patch_qwen3_next")
def patch_qwen3_next_gdn_layerwise(mod):
    """Apply this file's two GDN patches to ``mod``.

    Registered through ``when_imported`` for
    ``vllm_ascend.patch.worker.patch_qwen3_next``. ``mod`` is whatever that
    hook passes in (presumably the imported module; verify against
    ``when_imported``).
    """
    logger.debug(f"Patched {mod} called")
    # Same order as before: the empty-save filter first, then the
    # _forward_core wrapper.
    for apply_patch in (_patch_empty_gdn_save, _wrap_gdn_forward_core):
        apply_patch(mod)
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect

import ucm.integration.vllm.patch.v0180.vllm_ascend.pc.worker.patch_qwen3_next # noqa: F401
import ucm.integration.vllm.patch.v0180.vllm_ascend.ucm_connector_patch # noqa: F401
from ucm.integration.vllm.patch.utils import (
patch_or_inject,
Expand Down
Loading
Loading