[feat] Support layerwise UCM connector for hybrid linear-attention models #984

Open
wangwenxin0312 wants to merge 4 commits into ModelEngine-Group:develop from wangwenxin0312:dev_hybrid_layer_pr

Conversation

@wangwenxin0312

Contributor

Purpose

Introduce UCMHybridLinearAttentionLayerWiseConnector to support layerwise KV load/store for hybrid full-attention and linear-attention layouts.

Modifications

  • Add layer-to-row mapping in the KV cache layout.
  • Load KV cache row by row and wait only when the target layer is reached. Dump KV cache at the last layer of each storage row to avoid repeated saves.
  • Add Qwen3-Next GDN patch on vLLM-Ascend to support layerwise load/save.
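As a rough illustration of the row mapping described above, the sketch below groups layers by storage row and picks, for each row, the layer at which a pending load would be awaited and the layer at which the row would be dumped. The function `build_row_layout`, the layer names, and the assumption that the first layer of a row is the wait point are hypothetical; only the attribute name `row_save_layer` appears in this PR.

```python
from collections import defaultdict


def build_row_layout(layer_to_row):
    """Group layers by storage row and pick the wait/save layer of each row.

    layer_to_row maps a layer name to a row index. Insertion order is
    assumed to follow execution order (an assumption of this sketch).
    """
    row_to_layers = defaultdict(list)
    for layer_name, row_id in layer_to_row.items():
        row_to_layers[row_id].append(layer_name)
    # Wait for a row's load only when its first layer is reached.
    row_wait_layer = {row: layers[0] for row, layers in row_to_layers.items()}
    # Dump a row once, at its last layer, so it is not saved repeatedly.
    row_save_layer = {row: layers[-1] for row, layers in row_to_layers.items()}
    return row_wait_layer, row_save_layer


# Hypothetical hybrid layout: several layers can share one storage row.
layer_to_row = {
    "layers.0.linear_attn": 0,
    "layers.1.linear_attn": 0,
    "layers.2.self_attn": 1,
    "layers.3.linear_attn": 2,
    "layers.4.self_attn": 2,
}
row_wait_layer, row_save_layer = build_row_layout(layer_to_row)
```

With this layout, row 0 would be awaited at `layers.0.linear_attn` and dumped at `layers.1.linear_attn`.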

Test

python examples/offline_inference.py
image

req_id = request.request_id
req_block_id_groups = self.kv_cache_manager.get_block_ids(req_id)
if not req_block_id_groups:
    continue

Contributor

The method _update_requests_with_invalid_blocks has complex nested control flow with multiple continue statements. Consider refactoring into smaller helper methods or using early returns to improve readability.

marked_invalid_block = True
request.num_computed_tokens = idx * self.block_size
num_affected_tokens = (
    req_num_computed_tokens - request.num_computed_tokens

Contributor

This block handles the case where a request is affected but no invalid block was directly processed. The calculation request.num_computed_tokens - request.num_cached_tokens assumes num_computed_tokens >= num_cached_tokens. If this invariant doesn't hold, the total affected tokens could be negative. Consider adding an assertion.
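One way to make that invariant explicit is sketched below. The helper `count_affected_tokens` is hypothetical and takes plain integers rather than a request object; it only shows where an assertion could sit.

```python
def count_affected_tokens(num_computed_tokens: int, num_cached_tokens: int) -> int:
    """Return how many computed tokens lie beyond the cached prefix."""
    # Fail loudly if the assumed invariant is violated instead of
    # silently producing a negative count.
    assert num_computed_tokens >= num_cached_tokens, (
        f"num_computed_tokens ({num_computed_tokens}) < "
        f"num_cached_tokens ({num_cached_tokens})"
    )
    return num_computed_tokens - num_cached_tokens
```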


if block_id in marked_invalid_block_ids:
    continue

Contributor

The variable marked_invalid_block is set to True after processing the first invalid block, but request.num_computed_tokens is updated for each invalid block encountered. This means the final value will be based on the last invalid block, not the first. Is this the intended behavior?
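If the intent is to roll back to the first invalid block, the scan can stop as soon as one is found. The sketch below assumes that intent; `truncate_at_first_invalid` and its flat argument list are hypothetical and stand in for the request and scheduler state used in the PR.

```python
def truncate_at_first_invalid(block_ids, invalid_block_ids, block_size, num_computed_tokens):
    """Return (new_num_computed_tokens, num_affected_tokens).

    Computed tokens are rolled back to the start of the first invalid
    block; later invalid blocks cannot move the boundary again.
    """
    for idx, block_id in enumerate(block_ids):
        if block_id in invalid_block_ids:
            new_computed = min(num_computed_tokens, idx * block_size)
            return new_computed, num_computed_tokens - new_computed
    # No invalid block: nothing to recompute.
    return num_computed_tokens, 0
```

For example, with blocks `[10, 11, 12, 13]`, invalid blocks `{11, 13}`, a block size of 16, and 64 computed tokens, the boundary lands at block index 1, leaving 16 computed tokens and 48 affected tokens.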

    maybe_save_kv_layer_to_connector,
    wait_for_kv_layer_from_connector,
)
from vllm_ascend.utils import vllm_version_is

Contributor

The code accesses self.kv_cache[kv_cache_index] without checking if self.kv_cache exists or if the index is valid. This could raise AttributeError or IndexError at runtime.
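A guarded lookup along these lines would avoid both exceptions. This is a sketch only: `get_layer_kv_cache` is a hypothetical helper, and returning `None` (rather than raising a clearer error) is an assumption about what the caller wants.

```python
from types import SimpleNamespace


def get_layer_kv_cache(layer, kv_cache_index: int):
    """Return layer.kv_cache[kv_cache_index], or None if it is unavailable."""
    kv_cache = getattr(layer, "kv_cache", None)
    if kv_cache is None or not (0 <= kv_cache_index < len(kv_cache)):
        return None
    return kv_cache[kv_cache_index]


# Stand-in objects for illustration; real layers hold tensors.
layer_with_cache = SimpleNamespace(kv_cache=["cache-0"])
layer_without_cache = SimpleNamespace()
```

The caller can then skip the save when the helper returns `None`.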

attn_metadata = forward_context.attn_metadata
should_save = False
if isinstance(attn_metadata, dict):
    layer_attn_metadata = attn_metadata.get(self.prefix)

Contributor

Imports inside the wrapper function make runtime import failures difficult to debug. Consider moving imports to module level or adding explicit error handling.


if should_save:
    kv_cache_index = (
        forward_context.virtual_engine if vllm_version_is("0.18.0") else 0

Contributor

After layer_attn_metadata = attn_metadata.get(self.prefix), the code should check if layer_attn_metadata is not None before the isinstance check.
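A minimal sketch of the suggested check follows. `GDNMetadataStub` is a hypothetical stand-in for whatever metadata class the patch tests against, and `should_save_layer` is not a function from the PR. Note that `isinstance(None, SomeClass)` already evaluates to `False` in Python, so the explicit `None` check mainly documents intent.

```python
class GDNMetadataStub:
    """Hypothetical stand-in for the layer's attention metadata class."""


def should_save_layer(attn_metadata, prefix) -> bool:
    """Decide whether a layer has metadata that warrants a save."""
    if isinstance(attn_metadata, dict):
        # dict.get returns None when the prefix is missing.
        layer_attn_metadata = attn_metadata.get(prefix)
    else:
        layer_attn_metadata = attn_metadata
    if layer_attn_metadata is None:
        return False
    return isinstance(layer_attn_metadata, GDNMetadataStub)
```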

Comment thread ucm/integration/vllm/ucm_connector.py Outdated
if request_id in self._failure_req_ids:
    continue
try:
    shard_indexs = [row_id] * len(ucm_block_ids)

Contributor

The logic for selecting self.row_save_layer uses max() with layer_name_to_id as the key. This behavior should be explicitly documented to explain the design decision.

def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
    metadata = self._get_connector_metadata()
    assert isinstance(metadata, UCMConnectorMetadata)
    self.load_tasks.clear()

Contributor

Error handling inconsistency: In this method, exceptions are caught differently than in wait_for_layer_load. Consider extracting a common error handling pattern.

Comment thread ucm/integration/vllm/ucm_connector.py Outdated
if len(ucm_block_ids) == 0:
    continue
self.need_load = True
if self.tp_rank % self.tp_size != 0 and not self.is_mla:

Contributor

The method clears load_tasks, request_data, and _failure_req_ids at the start without checking if they're empty. If called multiple times, pending tasks could be lost.

        return
    self._submit_request_load_tasks_for_row_once(next_row_id, metadata)

def save_kv_layer(

Contributor

Using try-except (ValueError, IndexError) to handle the 'last row' case is implicit and could hide actual errors. Better to check explicitly.
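An explicit check could look like the sketch below. It assumes the rows are available as an ordered list of row ids; `next_row_id` as a function and the choice to raise `KeyError` for an unknown row are illustrative, not taken from the PR.

```python
def next_row_id(row_ids, current_row_id):
    """Return the row that follows current_row_id, or None for the last row."""
    position = {row_id: i for i, row_id in enumerate(row_ids)}
    if current_row_id not in position:
        # A genuinely unknown row is an error, not the "last row" case.
        raise KeyError(f"unknown row id: {current_row_id}")
    nxt = position[current_row_id] + 1
    return row_ids[nxt] if nxt < len(row_ids) else None
```

The last row is then signalled by `None`, while an unknown row id still raises instead of being mistaken for the end of the sequence.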

logger.error(f"wait for dump kv cache failed. {type(e).__name__}: {e}")
self.dump_tasks.clear()
self.is_save = False
if self.enable_event_sync:

Contributor

Catching all exceptions and only logging them could hide critical failures. If the KV cache dump fails, the system may continue with corrupted data.
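One option is to keep logging but also report which requests failed, so the caller can invalidate their blocks instead of continuing silently. The sketch below assumes dump tasks are keyed by request id and awaited through a caller-supplied `wait_fn`; both are assumptions made for the example.

```python
import logging

logger = logging.getLogger(__name__)


def wait_for_dumps(dump_tasks, wait_fn):
    """Wait on every dump task and return the ids of requests whose dump failed."""
    failed = set()
    for request_id, task in dump_tasks.items():
        try:
            wait_fn(task)
        except Exception as e:
            logger.error(
                "wait for dump kv cache failed for %s. %s: %s",
                request_id, type(e).__name__, e,
            )
            failed.add(request_id)
    dump_tasks.clear()
    return failed
```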

    self.connector = UCMCPConnector(vllm_config, role, kv_cache_config)
elif use_layerwise:
    self.connector = UCMLayerWiseConnector(vllm_config, role, kv_cache_config)
elif use_hybrid_linear_attention and use_layerwise:

Contributor

The connector selection order has changed. This changes behavior for configurations where both use_layerwise and use_hybrid_linear_attention flags are set. This should be documented as a potential breaking change.
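Ordering matters here because an `elif` branch whose condition implies an earlier branch's condition can never run. In the excerpt as shown, `use_hybrid_linear_attention and use_layerwise` comes after `elif use_layerwise`, so it would be unreachable. The sketch below checks the more specific combination first. It returns class names as strings to stay self-contained; `use_cp` and the `"default"` fallback are hypothetical placeholders for branches not visible in the hunk.

```python
def select_connector_name(use_cp: bool, use_layerwise: bool,
                          use_hybrid_linear_attention: bool) -> str:
    """Pick a connector, testing the most specific flag combination first."""
    if use_cp:  # hypothetical flag for the first, unseen branch
        return "UCMCPConnector"
    if use_hybrid_linear_attention and use_layerwise:
        return "UCMHybridLinearAttentionLayerWiseConnector"
    if use_layerwise:
        return "UCMLayerWiseConnector"
    return "default"  # placeholder for the remaining branches
```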

blocks_to_evict.update(req_block_ids[idx:])

if is_affected:
    if not marked_invalid_block:

Contributor

⚠️ Warning: marked_invalid_block is set inside the for-loop (line 79) but used here after the loop. The variable only reflects the last block iteration's state, not the overall marking status. If multiple blocks are invalid and the last one is already in marked_invalid_block_ids, marked_invalid_block stays False even though earlier blocks were marked. Consider tracking marked_invalid_block at the request level outside the inner loop.
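A request-level flag might be tracked as in the sketch below. `scan_request` is a hypothetical helper, and it assumes `marked_invalid_block_ids` is a set shared across requests within one scheduling pass.

```python
def scan_request(block_ids, invalid_block_ids, marked_invalid_block_ids):
    """Return (is_affected, marked_any) for one request.

    marked_any is set once and never reset inside the loop, so it reflects
    the whole request rather than the last block visited.
    """
    is_affected = False
    marked_any = False
    for block_id in block_ids:
        if block_id not in invalid_block_ids:
            continue
        is_affected = True
        if block_id in marked_invalid_block_ids:
            continue  # already marked while scanning an earlier request
        marked_invalid_block_ids.add(block_id)
        marked_any = True
    return is_affected, marked_any
```

A request that only touches blocks already marked by an earlier request then comes back as `(True, False)`, which is the case the `if not marked_invalid_block` branch appears to target.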


@wraps(original_forward_core)
def ucm_forward_core(self, mixed_qkv, b, a, core_attn_out):
    from vllm.forward_context import get_forward_context

Contributor

💡 Suggestion: Importing modules (vllm.forward_context, vllm.v1.attention.backends.gdn_attn, etc.) inside the wrapped function body causes these imports to execute on every forward call. This is inefficient and could raise ImportError at runtime if modules aren't available. Consider moving imports to the top of the function scope or using lazy import patterns.
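One possible pattern is to resolve the dependency once, when the patch is applied, so a missing module fails early with a clear message. After the first import, a function-level `import` is a lookup in `sys.modules` rather than a full reload, so the cost per call is small, but the failure mode is still easier to diagnose at patch time. The sketch below uses a generic `patch_forward` helper and a module name passed as a string; both are hypothetical and unrelated to the actual vLLM-Ascend patch code.

```python
import importlib
from functools import wraps


def patch_forward(original_forward, required_module: str):
    """Wrap original_forward after resolving required_module once."""
    try:
        # Resolved at patch time, not on every forward call.
        module = importlib.import_module(required_module)
    except ImportError as e:
        raise ImportError(f"patch requires {required_module!r}: {e}") from e

    @wraps(original_forward)
    def patched(*args, **kwargs):
        # `module` is captured by the closure and available here.
        assert module is not None
        return original_forward(*args, **kwargs)

    return patched
```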

for request_id in request_ids:
    request_meta = metadata.request_meta.get(request_id)
    if request_meta is not None:
        self._invalid_block_ids.update(request_meta.load_block_ids[1])

Contributor

⚠️ Warning: Accessing request_meta.load_block_ids[1] without checking if load_block_ids attribute exists or has at least 2 elements could raise AttributeError or IndexError. Consider adding: if hasattr(request_meta, 'load_block_ids') and len(request_meta.load_block_ids) >= 2:
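The guard could also be written with `getattr`, as in this sketch. `collect_invalid_block_ids` is a hypothetical free function, and `SimpleNamespace` objects stand in for the real request metadata.

```python
from types import SimpleNamespace


def collect_invalid_block_ids(request_ids, request_meta_by_id):
    """Gather load_block_ids[1] for every request that actually has it."""
    invalid = set()
    for request_id in request_ids:
        meta = request_meta_by_id.get(request_id)
        # getattr on None also returns the default, so one check covers both cases.
        load_block_ids = getattr(meta, "load_block_ids", None)
        if load_block_ids is None or len(load_block_ids) < 2:
            continue
        invalid.update(load_block_ids[1])
    return invalid


# Stand-in metadata for illustration.
request_meta_by_id = {
    "r1": SimpleNamespace(load_block_ids=(["h1"], [7, 8])),
    "r2": SimpleNamespace(load_block_ids=(["h2"],)),
    "r3": SimpleNamespace(),
}
```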

2 participants