[feat] Support layerwise UCM connector for hybrid linear-attention models#984
wangwenxin0312 wants to merge 4 commits into
| req_id = request.request_id | ||
| req_block_id_groups = self.kv_cache_manager.get_block_ids(req_id) | ||
| if not req_block_id_groups: | ||
| continue |
The method _update_requests_with_invalid_blocks has complex nested control flow with multiple continue statements. Consider refactoring into smaller helper methods or using early returns to improve readability.
| marked_invalid_block = True | ||
| request.num_computed_tokens = idx * self.block_size | ||
| num_affected_tokens = ( | ||
| req_num_computed_tokens - request.num_computed_tokens |
This block handles the case where a request is affected but no invalid block was directly processed. The calculation request.num_computed_tokens - request.num_cached_tokens assumes num_computed_tokens >= num_cached_tokens. If this invariant doesn't hold, the total affected tokens could be negative. Consider adding an assertion.
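A minimal sketch of the suggested assertion, assuming both counters are plain integers. `affected_token_count` is a hypothetical helper for illustration, not a function in this PR:

```python
def affected_token_count(num_computed_tokens: int, num_cached_tokens: int) -> int:
    """Return how many computed tokens are not covered by the cache."""
    # Hypothetical invariant check: fail loudly instead of silently
    # producing a negative token count.
    assert num_computed_tokens >= num_cached_tokens, (
        f"num_computed_tokens ({num_computed_tokens}) < "
        f"num_cached_tokens ({num_cached_tokens})"
    )
    return num_computed_tokens - num_cached_tokens
```

If assertions may be stripped with `python -O`, raising an explicit exception would be the safer variant.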
| if block_id in marked_invalid_block_ids: | ||
| continue | ||
The variable marked_invalid_block is set to True after processing the first invalid block, but request.num_computed_tokens is updated for each invalid block encountered. This means the final value will be based on the last invalid block, not the first. Is this the intended behavior?
| maybe_save_kv_layer_to_connector, | ||
| wait_for_kv_layer_from_connector, | ||
| ) | ||
| from vllm_ascend.utils import vllm_version_is |
The code accesses self.kv_cache[kv_cache_index] without checking if self.kv_cache exists or if the index is valid. This could raise AttributeError or IndexError at runtime.
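One possible shape for the guard, assuming `kv_cache` is an indexable sequence stored on the layer object. `get_layer_kv_cache` is a hypothetical helper, and returning `None` rather than raising is a design choice the author may not want:

```python
from typing import Any, Optional


def get_layer_kv_cache(layer: Any, kv_cache_index: int) -> Optional[Any]:
    """Return layer.kv_cache[kv_cache_index], or None if it is unavailable."""
    kv_cache = getattr(layer, "kv_cache", None)  # avoids AttributeError
    if kv_cache is None:
        return None
    if not 0 <= kv_cache_index < len(kv_cache):  # avoids IndexError
        return None
    return kv_cache[kv_cache_index]
```

The caller can then skip the save for that layer when the result is `None`.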
| attn_metadata = forward_context.attn_metadata | ||
| should_save = False | ||
| if isinstance(attn_metadata, dict): | ||
| layer_attn_metadata = attn_metadata.get(self.prefix) |
Imports inside the wrapper function make runtime import failures difficult to debug. Consider moving imports to module level or adding explicit error handling.
| if should_save: | ||
| kv_cache_index = ( | ||
| forward_context.virtual_engine if vllm_version_is("0.18.0") else 0 |
After layer_attn_metadata = attn_metadata.get(self.prefix), the code should check if layer_attn_metadata is not None before the isinstance check.
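A sketch of the explicit check, assuming `attn_metadata` is either a dict keyed by layer prefix or a single metadata object. `layer_metadata_or_none` and its `expected_type` parameter are hypothetical names, not part of this PR:

```python
from typing import Any, Optional


def layer_metadata_or_none(
    attn_metadata: Any, prefix: str, expected_type: type
) -> Optional[Any]:
    """Return this layer's metadata if present and of the expected type."""
    if isinstance(attn_metadata, dict):
        layer_attn_metadata = attn_metadata.get(prefix)  # None if prefix missing
    else:
        layer_attn_metadata = attn_metadata
    if layer_attn_metadata is None:
        return None  # explicit: no metadata for this layer
    if not isinstance(layer_attn_metadata, expected_type):
        return None
    return layer_attn_metadata
```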
| if request_id in self._failure_req_ids: | ||
| continue | ||
| try: | ||
| shard_indexs = [row_id] * len(ucm_block_ids) |
The logic for selecting self.row_save_layer uses max() with layer_name_to_id as the key. This behavior should be explicitly documented to explain the design decision.
| def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: | ||
| metadata = self._get_connector_metadata() | ||
| assert isinstance(metadata, UCMConnectorMetadata) | ||
| self.load_tasks.clear() |
Error handling inconsistency: In this method, exceptions are caught differently than in wait_for_layer_load. Consider extracting a common error handling pattern.
| if len(ucm_block_ids) == 0: | ||
| continue | ||
| self.need_load = True | ||
| if self.tp_rank % self.tp_size != 0 and not self.is_mla: |
The method clears load_tasks, request_data, and _failure_req_ids at the start without checking if they're empty. If called multiple times, pending tasks could be lost.
| return | ||
| self._submit_request_load_tasks_for_row_once(next_row_id, metadata) | ||
| def save_kv_layer( |
Using try-except (ValueError, IndexError) to handle the 'last row' case is implicit and could hide actual errors. Better to check explicitly.
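A sketch of the explicit check, assuming rows are kept in an ordered list of ids. `row_ids` and this standalone `find_next_row_id` function are hypothetical, since the excerpt does not show how the PR stores rows:

```python
from typing import List, Optional


def find_next_row_id(row_ids: List[int], current_row_id: int) -> Optional[int]:
    """Return the row after current_row_id, or None if it is the last row."""
    if current_row_id not in row_ids:
        # A genuinely unknown row is an error, not the "last row" case.
        raise ValueError(f"unknown row id: {current_row_id}")
    position = row_ids.index(current_row_id)
    if position == len(row_ids) - 1:
        return None  # last row: nothing further to prefetch
    return row_ids[position + 1]
```

With this shape the caller returns early on `None`, and an unexpected row id still surfaces as an error.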
| logger.error(f"wait for dump kv cache failed. {type(e).__name__}: {e}") | ||
| self.dump_tasks.clear() | ||
| self.is_save = False | ||
| if self.enable_event_sync: |
Catching all exceptions and only logging them could hide critical failures. If the KV cache dump fails, the system may continue with corrupted data.
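One way to avoid swallowing the failure is to record which requests failed and hand that set back to the caller. This is a sketch only: `wait_for_dump_tasks`, the `dump_tasks` mapping, and `wait_fn` are hypothetical stand-ins for whatever the connector actually uses:

```python
import logging
from typing import Any, Callable, Dict, Set

logger = logging.getLogger(__name__)


def wait_for_dump_tasks(
    dump_tasks: Dict[str, Any], wait_fn: Callable[[Any], None]
) -> Set[str]:
    """Wait on every dump task and return the ids of requests that failed."""
    failed_req_ids: Set[str] = set()
    for request_id, task in dump_tasks.items():
        try:
            wait_fn(task)
        except Exception as e:
            # Log as before, but also remember the failure so the caller
            # can invalidate or retry the affected request.
            logger.error(
                "wait for dump kv cache failed for %s. %s: %s",
                request_id, type(e).__name__, e,
            )
            failed_req_ids.add(request_id)
    return failed_req_ids
```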
| self.connector = UCMCPConnector(vllm_config, role, kv_cache_config) | ||
| elif use_layerwise: | ||
| self.connector = UCMLayerWiseConnector(vllm_config, role, kv_cache_config) | ||
| elif use_hybrid_linear_attention and use_layerwise: |
The connector selection order has changed. This changes behavior for configurations where both use_layerwise and use_hybrid_linear_attention flags are set. This should be documented as a potential breaking change.
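To make the ordering concern concrete: in an `if`/`elif` chain the first matching branch wins, so a combined condition must be tested before the broader one it overlaps with. The sketch below assumes the hybrid connector is meant to win when both flags are set. `select_connector_name` is hypothetical and returns class names as strings, with `"default"` standing in for branches not shown in the excerpt:

```python
def select_connector_name(
    use_layerwise: bool, use_hybrid_linear_attention: bool
) -> str:
    """Pick a connector name, testing the most specific condition first."""
    if use_hybrid_linear_attention and use_layerwise:
        return "UCMHybridLinearAttentionLayerWiseConnector"
    if use_layerwise:
        # Reached only when the hybrid flag is not set.
        return "UCMLayerWiseConnector"
    return "default"  # placeholder for the remaining branches
```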
| blocks_to_evict.update(req_block_ids[idx:]) | ||
| if is_affected: | ||
| if not marked_invalid_block: |
marked_invalid_block is set inside the for-loop (line 79) but used here after the loop. The variable only reflects the last block iteration's state, not the overall marking status. If multiple blocks are invalid and the last one is already in marked_invalid_block_ids, marked_invalid_block stays False even though earlier blocks were marked. Consider tracking marked_invalid_block at the request level outside the inner loop.
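A sketch of request-level tracking, assuming the intent is to truncate at the first newly marked invalid block. `truncate_at_first_invalid` and its parameters are hypothetical and simplify away the scheduler state:

```python
from typing import List, Optional, Set, Tuple


def truncate_at_first_invalid(
    req_block_ids: List[int],
    invalid_block_ids: Set[int],
    marked_invalid_block_ids: Set[int],
    block_size: int,
) -> Tuple[bool, Optional[int]]:
    """Return (marked_invalid_block, new num_computed_tokens or None)."""
    marked_invalid_block = False  # request-level flag, set at most once
    num_computed_tokens: Optional[int] = None
    for idx, block_id in enumerate(req_block_ids):
        if block_id not in invalid_block_ids:
            continue
        if block_id in marked_invalid_block_ids:
            continue  # already handled elsewhere; the flag is left untouched
        marked_invalid_block_ids.add(block_id)
        if not marked_invalid_block:
            marked_invalid_block = True
            # Truncate at the first newly marked block, not the last one.
            num_computed_tokens = idx * block_size
    return marked_invalid_block, num_computed_tokens
```

Because the flag is only ever switched on, a later block that was already marked cannot hide an earlier one.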
| @wraps(original_forward_core) | ||
| def ucm_forward_core(self, mixed_qkv, b, a, core_attn_out): | ||
| from vllm.forward_context import get_forward_context |
💡 Suggestion: Importing modules (vllm.forward_context, vllm.v1.attention.backends.gdn_attn, etc.) inside the wrapped function body means the import statements run on every forward call. After the first call each one is only a sys.modules lookup, so the overhead is small, but a missing module would surface as an ImportError in the middle of a forward pass rather than at load time. Consider moving the imports out of the wrapper body (to the enclosing function or module scope) or using a lazy import pattern.
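A lazy import pattern could look like the sketch below. `lazy_module` is a hypothetical helper built only on the standard library, and `json` stands in for the vLLM modules so the example runs anywhere:

```python
import importlib
from functools import lru_cache
from types import ModuleType


@lru_cache(maxsize=None)
def lazy_module(name: str) -> ModuleType:
    """Import a module on first use and reuse the cached module afterwards."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        # Re-raise with a clearer message; failed lookups are not cached.
        raise ImportError(f"required module {name!r} is not available") from exc


def encode(values):
    # The wrapper body asks for the module by name; only the first call
    # goes through importlib.
    return lazy_module("json").dumps(values)
```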
| for request_id in request_ids: | ||
| request_meta = metadata.request_meta.get(request_id) | ||
| if request_meta is not None: | ||
| self._invalid_block_ids.update(request_meta.load_block_ids[1]) |
request_meta.load_block_ids[1] without checking if load_block_ids attribute exists or has at least 2 elements could raise AttributeError or IndexError. Consider adding: if hasattr(request_meta, 'load_block_ids') and len(request_meta.load_block_ids) >= 2:
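The same guard written as a small helper, assuming `load_block_ids` is a sequence whose second element holds the block ids. `invalid_ids_from` is a hypothetical name:

```python
from typing import Any, List


def invalid_ids_from(request_meta: Any) -> List[int]:
    """Return request_meta.load_block_ids[1] as a list, or [] if unavailable."""
    load_block_ids = getattr(request_meta, "load_block_ids", None)
    if load_block_ids is None or len(load_block_ids) < 2:
        return []  # nothing to mark invalid for this request
    return list(load_block_ids[1])
```

The call site would then reduce to `self._invalid_block_ids.update(invalid_ids_from(request_meta))`.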
Purpose
Introduce UCMHybridLinearAttentionLayerWiseConnector to support layerwise KV load/store for hybrid full-attention and linear-attention layouts.
Modifications
Test
python examples/offline_inference.py
