-
Notifications
You must be signed in to change notification settings - Fork 86
[feat] Support layerwise UCM connector for hybrid linear-attention models #984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,8 @@ | ||
| from collections.abc import Iterable | ||
|
|
||
| from vllm.v1.request import Request, RequestStatus | ||
|
|
||
|
|
||
| class Scheduler: | ||
| def _mamba_block_aligned_split( | ||
| self, | ||
|
|
@@ -26,3 +31,68 @@ def _mamba_block_aligned_split( | |
| ): | ||
| num_new_tokens = last_cache_position - num_computed_tokens | ||
| return num_new_tokens | ||
|
|
||
def _update_requests_with_invalid_blocks(
    self,
    requests: Iterable[Request],
    invalid_block_ids: set[int],
    evict_blocks: bool = True,
) -> tuple[set[str], int, set[int]]:
    """Roll back computed-token counts of requests that reference invalid blocks.

    For every request, the block ids of KV group 0 that cover the request's
    computed tokens are scanned. At the first invalid block that no earlier
    request in this call has already claimed, ``request.num_computed_tokens``
    is truncated to the start of that block and
    ``request.num_external_computed_tokens`` is reduced by the number of
    tokens dropped.

    Args:
        requests: Requests to inspect. They are mutated in place.
        invalid_block_ids: Block ids to treat as invalid.
        evict_blocks: When true, the claimed invalid block and every later
            block of the request (in group 0) are added to the eviction set.

    Returns:
        A tuple ``(affected_req_ids, total_affected_tokens, blocks_to_evict)``:
        the ids of requests that reference at least one invalid block inside
        their computed range, the accumulated token adjustment, and the block
        ids collected for eviction (empty when ``evict_blocks`` is false).
    """
    affected_req_ids: set[str] = set()
    total_affected_tokens = 0
    blocks_to_evict: set[int] = set()
    # Invalid block ids already claimed by a request earlier in this call.
    # A later request that shares such a block is still reported as affected
    # but does not truncate on it.
    marked_invalid_block_ids: set[int] = set()
    for request in requests:
        is_affected = False
        marked_invalid_block = False
        req_id = request.request_id
        req_block_id_groups = self.kv_cache_manager.get_block_ids(req_id)
        if not req_block_id_groups:
            # Nothing to inspect for this request.
            continue
        # vLLM v0.18's recovery path assumes one KV group. Hybrid
        # Qwen3-Next has multiple groups, and UCM's HMA dispatch always
        # includes the full-attention group as group 0, so use it as the
        # recovery anchor.
        req_block_ids = req_block_id_groups[0]
        # The token count that bounds the scan depends on the request status.
        if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
            req_num_computed_tokens = request.num_computed_tokens
        else:
            req_num_computed_tokens = request.num_cached_tokens

        # Ceiling division: a partially filled last block is scanned too.
        req_num_computed_blocks = (
            req_num_computed_tokens + self.block_size - 1
        ) // self.block_size
        # zip() stops at the shorter input, so idx never runs past the block
        # ids actually held for the request.
        for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids):
            if block_id not in invalid_block_ids:
                continue

            is_affected = True

            if block_id in marked_invalid_block_ids:
                # Already claimed by an earlier request in this call.
                continue

            marked_invalid_block_ids.add(block_id)

            if marked_invalid_block:
                # This request was already truncated at an earlier block;
                # the block is only recorded as claimed.
                continue

            marked_invalid_block = True
            # Truncate at the start of the first claimed invalid block.
            request.num_computed_tokens = idx * self.block_size
            num_affected_tokens = (
                req_num_computed_tokens - request.num_computed_tokens
            )
            total_affected_tokens += num_affected_tokens
            request.num_external_computed_tokens -= num_affected_tokens
            if evict_blocks:
                # The invalid block and everything after it in group 0.
                blocks_to_evict.update(req_block_ids[idx:])

        if is_affected:
            if not marked_invalid_block:
                # Every invalid block of this request was claimed by an
                # earlier request: account for the difference between
                # num_computed_tokens and num_cached_tokens, then reset
                # num_computed_tokens to num_cached_tokens.
                total_affected_tokens += (
                    request.num_computed_tokens - request.num_cached_tokens
                )
                request.num_computed_tokens = request.num_cached_tokens

            affected_req_ids.add(request.request_id)

    return affected_req_ids, total_affected_tokens, blocks_to_evict
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| from functools import wraps | ||
|
|
||
| from ucm.integration.vllm.patch.utils import patch_or_inject, when_imported | ||
| from ucm.logger import init_logger | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
def _patch_empty_gdn_save(mod) -> None:
    """Make ``mod.maybe_save_kv_layer_to_connector`` ignore empty placeholder saves.

    The attribute is replaced by a wrapper that returns ``None`` without
    calling the original when invoked with ``layer_name == ""`` and
    ``kv_cache_layer == []``; every other call is forwarded unchanged.

    Nothing happens when ``mod`` has no such attribute, or when the attribute
    already carries the ``_ucm_skip_empty_gdn_save`` marker (so applying the
    patch twice does not stack wrappers).
    """
    marker = "_ucm_skip_empty_gdn_save"
    wrapped_save = getattr(mod, "maybe_save_kv_layer_to_connector", None)
    if wrapped_save is None:
        return
    if getattr(wrapped_save, marker, False):
        return

    @wraps(wrapped_save)
    def skip_empty_gdn_save(layer_name, kv_cache_layer):
        is_placeholder = layer_name == "" and kv_cache_layer == []
        return None if is_placeholder else wrapped_save(layer_name, kv_cache_layer)

    # Tag the wrapper so a later call can recognise it and bail out early.
    setattr(skip_empty_gdn_save, marker, True)
    mod.maybe_save_kv_layer_to_connector = skip_empty_gdn_save
|
|
||
|
|
||
def _wrap_gdn_forward_core(mod) -> None:
    """Wrap the Qwen3-Next GDN ``_forward_core`` with UCM layerwise hooks.

    Looks up ``Qwen3NextGatedDeltaNet`` on ``mod`` and replaces its
    ``_forward_core`` with a wrapper that, for layers whose attention
    metadata is a ``GDNAttentionMetadata``, waits on the KV connector before
    running the original method and passes the layer's KV cache to the
    connector afterwards. The same wrapper is also installed on
    ``AscendQwen3Next_GatedDeltaNet`` when ``mod`` exposes that class.

    A warning is logged and nothing is patched when the class or the method
    is missing. The call is a silent no-op when the method already carries
    the ``_ucm_gdn_layerwise_patched`` marker.
    """
    gdn_cls = getattr(mod, "Qwen3NextGatedDeltaNet", None)
    if gdn_cls is None:
        logger.warning("Skip Qwen3Next GDN UCM patch: target class not found.")
        return

    wrapped_forward = getattr(gdn_cls, "_forward_core", None)
    if wrapped_forward is None:
        logger.warning("Skip Qwen3Next GDN UCM patch: _forward_core not found.")
        return

    already_patched = getattr(wrapped_forward, "_ucm_gdn_layerwise_patched", False)
    if already_patched:
        return

    @wraps(wrapped_forward)
    def ucm_forward_core(self, mixed_qkv, b, a, core_attn_out):
        # These imports are resolved on each call rather than at module load.
        # NOTE(review): a failed import therefore only surfaces at forward
        # time; presumably they are deferred because the packages may not be
        # importable when this patch module is loaded -- confirm.
        from vllm.forward_context import get_forward_context
        from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
        from vllm_ascend.attention.utils import (
            maybe_save_kv_layer_to_connector,
            wait_for_kv_layer_from_connector,
        )
        from vllm_ascend.utils import vllm_version_is

        ctx = get_forward_context()
        metadata = ctx.attn_metadata

        # Only a per-layer metadata dict can identify this layer as GDN.
        layer_metadata = None
        if isinstance(metadata, dict):
            layer_metadata = metadata.get(self.prefix)
        is_gdn_layer = isinstance(layer_metadata, GDNAttentionMetadata)

        if is_gdn_layer:
            wait_for_kv_layer_from_connector(self.prefix)

        output = wrapped_forward(self, mixed_qkv, b, a, core_attn_out)

        if not is_gdn_layer:
            return output

        # The index into self.kv_cache depends on the vLLM version.
        if vllm_version_is("0.18.0"):
            cache_slot = ctx.virtual_engine
        else:
            cache_slot = 0
        layer_cache = self.kv_cache[cache_slot]
        maybe_save_kv_layer_to_connector(self.prefix, list(layer_cache))
        return output

    # Marker read by the guard above on any later invocation.
    ucm_forward_core._ucm_gdn_layerwise_patched = True
    patch_or_inject(gdn_cls, "_forward_core", ucm_forward_core)

    ascend_cls = getattr(mod, "AscendQwen3Next_GatedDeltaNet", None)
    if ascend_cls is None:
        return
    patch_or_inject(ascend_cls, "_forward_core", ucm_forward_core)
|
|
||
|
|
||
@when_imported("vllm_ascend.patch.worker.patch_qwen3_next")
def patch_qwen3_next_gdn_layerwise(mod):
    """Apply the UCM GDN layerwise patches to ``mod``.

    Registered through ``when_imported`` for
    ``vllm_ascend.patch.worker.patch_qwen3_next``; presumably called with
    that module once it has been imported -- confirm against
    ``when_imported``.
    """
    logger.debug(f"Patched {mod} called")
    # Same order as before: empty-save guard first, then the forward wrapper.
    for apply_patch in (_patch_empty_gdn_save, _wrap_gdn_forward_core):
        apply_patch(mod)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The method `_update_requests_with_invalid_blocks` has complex nested control flow with multiple `continue` statements. Consider refactoring into smaller helper methods or using early returns to improve readability.