From d54bf328e1f24b1334ab9f68881f8979c31f91ee Mon Sep 17 00:00:00 2001 From: MaxWang Date: Thu, 28 May 2026 17:04:20 +0800 Subject: [PATCH 1/4] hybrid layerwise connector --- ucm/integration/vllm/ucm_connector.py | 232 +++++++++++++++++++++++++- 1 file changed, 230 insertions(+), 2 deletions(-) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 4607eb990..7efa04f6b 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -1999,6 +1999,7 @@ def _build_layout(self, kvcaches): buffer_size_rows = [] tensor_size_lists = [] block_stride_lists = [] + self.layer_name_to_row: dict[str, int] = {} for raw_tensor in self.kv_cache_config.kv_cache_tensors: if not raw_tensor.shared_by: @@ -2014,6 +2015,7 @@ def _build_layout(self, kvcaches): ) continue + row_id = len(base_ptrs) mamba_specs = [s for s in shared_specs if isinstance(s, MambaSpec)] attn_specs = [s for s in shared_specs if isinstance(s, FullAttentionSpec)] if not mamba_specs or not attn_specs: @@ -2025,6 +2027,8 @@ def _build_layout(self, kvcaches): tensor_size_lists, block_stride_lists, ) + for layer_name in raw_tensor.shared_by: + self.layer_name_to_row[layer_name] = row_id continue if current_platform.device_type == "npu": @@ -2048,6 +2052,9 @@ def _build_layout(self, kvcaches): block_stride_lists, ) + for layer_name in raw_tensor.shared_by: + self.layer_name_to_row[layer_name] = row_id + self.base_ptrs = np.asarray(base_ptrs, dtype=np.uint64) self.buffer_sizes = np.asarray(buffer_size_rows, dtype=np.uint64) self.tensor_size_lists = np.asarray(tensor_size_lists, dtype=np.uint64) @@ -2592,6 +2599,223 @@ def _create_kv_cache_layout( ) +class UCMHybridLinearAttentionLayerWiseConnector(UCMHybridLinearAttentionConnector): + """Layerwise connector for full-attention + linear-attention hybrid layouts.""" + + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__(vllm_config, role, kv_cache_config) + self.load_tasks: dict[int, dict[str, Task]] = defaultdict(dict) + self.dump_tasks: dict[int, list[Task]] = defaultdict(list) + self.request_data: list[tuple[str, list[bytes], np.ndarray]] = [] + self._failure_req_ids: set[str] = set() + self._submitted_load_rows: set[int] = set() + self.use_layerwise = True + self.is_save = False + self.need_load = False + logger.info("Init UCMHybridLinearAttentionLayerWiseConnector.") + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + super().register_kv_caches(kv_caches) + self.layer_name_to_id = self.kv_cache_layout.layer_name_to_id + self.layer_ids = sorted(set(self.layer_name_to_id.values())) + self.first_layer_id = self.layer_ids[0] + self.layer_name_to_row = getattr(self.kv_cache_layout, "layer_name_to_row", {}) + self.row_ids = sorted(set(self.layer_name_to_row.values())) + row_to_layers: dict[int, list[str]] = defaultdict(list) + for layer_name, row_id in self.layer_name_to_row.items(): + row_to_layers[row_id].append(layer_name) + self.row_save_layer = { + row_id: max( + layer_names, + key=lambda name: self.layer_name_to_id.get(name, self.first_layer_id), + ) + for row_id, layer_names in row_to_layers.items() + } + + def _submit_request_load_tasks_for_row( + self, + row_id: int, + metadata: "UCMConnectorMetadata", + ) -> None: + for request_id, ucm_block_ids, total_ptrs in self.request_data: + if request_id in self._failure_req_ids: + continue + try: + shard_indexs = [row_id] * len(ucm_block_ids) + row_ptrs = total_ptrs[row_id] + task = self.store.load_data(ucm_block_ids, shard_indexs, row_ptrs) + self.load_tasks[row_id][request_id] = task + except Exception as e: + logger.error( + f"request {request_id} submit row {row_id} load task error. " + f"{type(e).__name__}: {e}" + ) + self._invalid_block_ids.update( + metadata.request_meta[request_id].load_block_ids[1] + ) + self._failure_req_ids.add(request_id) + self._submitted_load_rows.add(row_id) + + def _submit_request_load_tasks_for_row_once( + self, + row_id: int, + metadata: "UCMConnectorMetadata", + ) -> None: + if row_id in self._submitted_load_rows: + return + self._submit_request_load_tasks_for_row(row_id, metadata) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + metadata = self._get_connector_metadata() + assert isinstance(metadata, UCMConnectorMetadata) + self.load_tasks.clear() + self.request_data.clear() + self._failure_req_ids.clear() + self._submitted_load_rows.clear() + self.need_load = False + + for request_id, request in metadata.request_meta.items(): + if len(request.load_block_ids[0]) == 0: + continue + + ucm_block_ids, vllm_block_ids = request.load_block_ids + if self._skip_null_vllm_blocks: + ucm_block_ids, vllm_block_ids = _drop_null_vllm_blocks( + ucm_block_ids, + vllm_block_ids, + f"UCM hybrid layerwise load request {request_id}", + ) + if len(ucm_block_ids) == 0: + continue + self.need_load = True + if self.tp_rank % self.tp_size != 0 and not self.is_mla: + for i, ucm_block_id in enumerate(ucm_block_ids): + ucm_block_ids[i] = self.request_hasher(ucm_block_id) + total_ptrs = self.kv_cache_layout.extract_block_addrs( + vllm_block_ids, layer_first=True + ) + self.request_data.append((request_id, ucm_block_ids, total_ptrs)) + + if self.need_load and self.row_ids: + self._submit_request_load_tasks_for_row_once(self.row_ids[0], metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + if not self._connector_metadata or not self.need_load: + return + metadata = self._get_connector_metadata() + assert isinstance(metadata, UCMConnectorMetadata) + row_id = self.layer_name_to_row.get(layer_name) + if row_id is None: + return + + self._submit_request_load_tasks_for_row_once(row_id, metadata) + row_tasks = self.load_tasks.pop(row_id, {}) + for request_id, task in row_tasks.items(): + try: + self.store.wait(task) + except Exception as e: + logger.error( + f"request {request_id} wait {layer_name} load failed. " + f"{type(e).__name__}: {e}" + ) + self._invalid_block_ids.update( + metadata.request_meta[request_id].load_block_ids[1] + ) + self._failure_req_ids.add(request_id) + + try: + next_row_id = self.row_ids[self.row_ids.index(row_id) + 1] + except (ValueError, IndexError): + return + self._submit_request_load_tasks_for_row_once(next_row_id, metadata) + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + if not self._connector_metadata: + return + if self.is_mla and self.tp_rank % self.tp_size != 0: + return + + row_id = self.layer_name_to_row.get(layer_name) + if row_id is None: + return + if self.row_save_layer.get(row_id) != layer_name: + return + + metadata = self._get_connector_metadata() + assert isinstance(metadata, UCMConnectorMetadata) + total_ucm_block_ids: list[bytes] = [] + total_vllm_block_ids: list[int] = [] + for _, request in metadata.request_meta.items(): + if len(request.dump_block_ids[0]) == 0: + continue + + ucm_block_ids, vllm_block_ids = request.dump_block_ids + if self._skip_null_vllm_blocks: + ucm_block_ids, vllm_block_ids = _drop_null_vllm_blocks( + ucm_block_ids, + vllm_block_ids, + f"UCM hybrid layerwise dump row {row_id}", + ) + if len(ucm_block_ids) == 0: + continue + if self.tp_rank % self.tp_size != 0: + for i, ucm_block_id in enumerate(ucm_block_ids): + ucm_block_ids[i] = self.request_hasher(ucm_block_id) + total_ucm_block_ids.extend(ucm_block_ids) + total_vllm_block_ids.extend(vllm_block_ids) + + if not total_ucm_block_ids: + return + + self.is_save = True + self._wait_dump_tasks_for_row(row_id) + total_ptrs = self.kv_cache_layout.extract_block_addrs( + total_vllm_block_ids, layer_first=True + ) + shard_indexs = [row_id] * len(total_ucm_block_ids) + try: + row_ptrs = np.ascontiguousarray(total_ptrs[row_id]) + event_handle = self._get_dump_event_handle() + task = self.store.dump_data( + total_ucm_block_ids, shard_indexs, row_ptrs, event_handle + ) + self.dump_tasks[row_id].append(task) + except Exception as e: + logger.error( + f"submit hybrid layerwise row {row_id} dump task failed. " + f"{type(e).__name__}: {e}" + ) + + def _wait_dump_tasks_for_row(self, row_id: int) -> None: + row_tasks = self.dump_tasks.pop(row_id, []) + for task in row_tasks: + self.store.wait(task) + + def wait_for_save(self) -> None: + if not self.is_save: + return + try: + for row_id in self.row_ids: + self._wait_dump_tasks_for_row(row_id) + except Exception as e: + logger.error(f"wait for dump kv cache failed. {type(e).__name__}: {e}") + self.dump_tasks.clear() + self.is_save = False + if self.enable_event_sync: + self.device.destroy_event_handles() + + class UCMConnector(KVConnectorBase_V1, SupportsHMA): def __init__( self, @@ -2660,12 +2884,16 @@ def __init__( self.connector = UCMMockConnector(vllm_config, role, kv_cache_config) elif use_cp_parallel: self.connector = UCMCPConnector(vllm_config, role, kv_cache_config) - elif use_layerwise: - self.connector = UCMLayerWiseConnector(vllm_config, role, kv_cache_config) + elif use_hybrid_linear_attention and use_layerwise: + self.connector = UCMHybridLinearAttentionLayerWiseConnector( + vllm_config, role, kv_cache_config + ) elif use_hybrid_linear_attention: self.connector = UCMHybridLinearAttentionConnector( vllm_config, role, kv_cache_config ) + elif use_layerwise: + self.connector = UCMLayerWiseConnector(vllm_config, role, kv_cache_config) # elif use_hma: # self.connector = UCMHMAConnector(vllm_config, role, kv_cache_config) else: From 551b9a1135317ba0d0079e5b44499452ccbee06e Mon Sep 17 00:00:00 2001 From: MaxWang Date: Thu, 28 May 2026 17:04:57 +0800 Subject: [PATCH 2/4] patch update --- .../v0180/vllm/pc/v1/core/sched/scheduler.py | 70 ++++++++++++++++ .../vllm/patch/v0180/vllm/pc_patch.py | 5 ++ .../vllm_ascend/pc/worker/patch_qwen3_next.py | 80 +++++++++++++++++++ .../v0180/vllm_ascend/pc_ascend_patch.py | 1 + 4 files changed, 156 insertions(+) create mode 100644 ucm/integration/vllm/patch/v0180/vllm_ascend/pc/worker/patch_qwen3_next.py diff --git a/ucm/integration/vllm/patch/v0180/vllm/pc/v1/core/sched/scheduler.py b/ucm/integration/vllm/patch/v0180/vllm/pc/v1/core/sched/scheduler.py index a010a300f..3769067fe 100644 --- a/ucm/integration/vllm/patch/v0180/vllm/pc/v1/core/sched/scheduler.py +++ b/ucm/integration/vllm/patch/v0180/vllm/pc/v1/core/sched/scheduler.py @@ -1,3 +1,8 @@ +from collections.abc import Iterable + +from vllm.v1.request import Request, RequestStatus + + class Scheduler: def _mamba_block_aligned_split( self, @@ -26,3 +31,68 @@ def _mamba_block_aligned_split( ): num_new_tokens = last_cache_position - num_computed_tokens return num_new_tokens + + def _update_requests_with_invalid_blocks( + self, + requests: Iterable[Request], + invalid_block_ids: set[int], + evict_blocks: bool = True, + ) -> tuple[set[str], int, set[int]]: + affected_req_ids: set[str] = set() + total_affected_tokens = 0 + blocks_to_evict: set[int] = set() + marked_invalid_block_ids: set[int] = set() + for request in requests: + is_affected = False + marked_invalid_block = False + req_id = request.request_id + req_block_id_groups = self.kv_cache_manager.get_block_ids(req_id) + if not req_block_id_groups: + continue + # vLLM v0.18's recovery path assumes one KV group. Hybrid + # Qwen3-Next has multiple groups, and UCM's HMA dispatch always + # includes the full-attention group as group 0, so use it as the + # recovery anchor. + req_block_ids = req_block_id_groups[0] + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + req_num_computed_tokens = request.num_computed_tokens + else: + req_num_computed_tokens = request.num_cached_tokens + + req_num_computed_blocks = ( + req_num_computed_tokens + self.block_size - 1 + ) // self.block_size + for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids): + if block_id not in invalid_block_ids: + continue + + is_affected = True + + if block_id in marked_invalid_block_ids: + continue + + marked_invalid_block_ids.add(block_id) + + if marked_invalid_block: + continue + + marked_invalid_block = True + request.num_computed_tokens = idx * self.block_size + num_affected_tokens = ( + req_num_computed_tokens - request.num_computed_tokens + ) + total_affected_tokens += num_affected_tokens + request.num_external_computed_tokens -= num_affected_tokens + if evict_blocks: + blocks_to_evict.update(req_block_ids[idx:]) + + if is_affected: + if not marked_invalid_block: + total_affected_tokens += ( + request.num_computed_tokens - request.num_cached_tokens + ) + request.num_computed_tokens = request.num_cached_tokens + + affected_req_ids.add(request.request_id) + + return affected_req_ids, total_affected_tokens, blocks_to_evict diff --git a/ucm/integration/vllm/patch/v0180/vllm/pc_patch.py b/ucm/integration/vllm/patch/v0180/vllm/pc_patch.py index 3fbd4b641..1251e828e 100644 --- a/ucm/integration/vllm/patch/v0180/vllm/pc_patch.py +++ b/ucm/integration/vllm/patch/v0180/vllm/pc_patch.py @@ -48,6 +48,11 @@ def patch_core_sched_scheduler(mod): "_mamba_block_aligned_split", scheduler.Scheduler._mamba_block_aligned_split, ) + patch_or_inject( + mod.Scheduler, + "_update_requests_with_invalid_blocks", + scheduler.Scheduler._update_requests_with_invalid_blocks, + ) @when_imported("vllm.v1.worker.gpu_model_runner") diff --git a/ucm/integration/vllm/patch/v0180/vllm_ascend/pc/worker/patch_qwen3_next.py b/ucm/integration/vllm/patch/v0180/vllm_ascend/pc/worker/patch_qwen3_next.py new file mode 100644 index 000000000..3a51485f2 --- /dev/null +++ b/ucm/integration/vllm/patch/v0180/vllm_ascend/pc/worker/patch_qwen3_next.py @@ -0,0 +1,80 @@ +from functools import wraps + +from ucm.integration.vllm.patch.utils import patch_or_inject, when_imported +from ucm.logger import init_logger + +logger = init_logger(__name__) + + +def _patch_empty_gdn_save(mod) -> None: + original_save = getattr(mod, "maybe_save_kv_layer_to_connector", None) + if original_save is None or getattr( + original_save, "_ucm_skip_empty_gdn_save", False + ): + return + + @wraps(original_save) + def skip_empty_gdn_save(layer_name, kv_cache_layer): + if layer_name == "" and kv_cache_layer == []: + return + return original_save(layer_name, kv_cache_layer) + + skip_empty_gdn_save._ucm_skip_empty_gdn_save = True + mod.maybe_save_kv_layer_to_connector = skip_empty_gdn_save + + +def _wrap_gdn_forward_core(mod) -> None: + target_cls = getattr(mod, "Qwen3NextGatedDeltaNet", None) + if target_cls is None: + logger.warning("Skip Qwen3Next GDN UCM patch: target class not found.") + return + + original_forward_core = getattr(target_cls, "_forward_core", None) + if original_forward_core is None: + logger.warning("Skip Qwen3Next GDN UCM patch: _forward_core not found.") + return + if getattr(original_forward_core, "_ucm_gdn_layerwise_patched", False): + return + + @wraps(original_forward_core) + def ucm_forward_core(self, mixed_qkv, b, a, core_attn_out): + from vllm.forward_context import get_forward_context + from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + from vllm_ascend.attention.utils import ( + maybe_save_kv_layer_to_connector, + wait_for_kv_layer_from_connector, + ) + from vllm_ascend.utils import vllm_version_is + + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + should_save = False + if isinstance(attn_metadata, dict): + layer_attn_metadata = attn_metadata.get(self.prefix) + if isinstance(layer_attn_metadata, GDNAttentionMetadata): + wait_for_kv_layer_from_connector(self.prefix) + should_save = True + + result = original_forward_core(self, mixed_qkv, b, a, core_attn_out) + + if should_save: + kv_cache_index = ( + forward_context.virtual_engine if vllm_version_is("0.18.0") else 0 + ) + self_kv_cache = self.kv_cache[kv_cache_index] + maybe_save_kv_layer_to_connector(self.prefix, list(self_kv_cache)) + return result + + ucm_forward_core._ucm_gdn_layerwise_patched = True + patch_or_inject(target_cls, "_forward_core", ucm_forward_core) + + ascend_cls = getattr(mod, "AscendQwen3Next_GatedDeltaNet", None) + if ascend_cls is not None: + patch_or_inject(ascend_cls, "_forward_core", ucm_forward_core) + + +@when_imported("vllm_ascend.patch.worker.patch_qwen3_next") +def patch_qwen3_next_gdn_layerwise(mod): + logger.debug(f"Patched {mod} called") + _patch_empty_gdn_save(mod) + _wrap_gdn_forward_core(mod) diff --git a/ucm/integration/vllm/patch/v0180/vllm_ascend/pc_ascend_patch.py b/ucm/integration/vllm/patch/v0180/vllm_ascend/pc_ascend_patch.py index f8ea222b3..6ef8d09a7 100644 --- a/ucm/integration/vllm/patch/v0180/vllm_ascend/pc_ascend_patch.py +++ b/ucm/integration/vllm/patch/v0180/vllm_ascend/pc_ascend_patch.py @@ -1,5 +1,6 @@ import inspect +import ucm.integration.vllm.patch.v0180.vllm_ascend.pc.worker.patch_qwen3_next # noqa: F401 import ucm.integration.vllm.patch.v0180.vllm_ascend.ucm_connector_patch # noqa: F401 from ucm.integration.vllm.patch.utils import ( patch_or_inject, From 13bff98ece8bbb33d19079fdf57f0ccb82718e85 Mon Sep 17 00:00:00 2001 From: MaxWang Date: Sat, 30 May 2026 17:40:34 +0800 Subject: [PATCH 3/4] prefetch --- ucm/integration/vllm/ucm_connector.py | 189 ++++++++++++++++++-------- 1 file changed, 131 insertions(+), 58 deletions(-) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 7efa04f6b..0ee126318 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -225,6 +225,15 @@ def extract_block_addrs( + self.base_ptrs[None, :, :] ) # (num_blocks, n_layers, n_ptrs) + def extract_block_addrs_for_row( + self, vllm_block_ids: List[int], row_id: int + ) -> np.ndarray: + vllm_block_ids_np = np.asarray(vllm_block_ids, dtype=np.uint64) + return ( + vllm_block_ids_np[:, None] * self.block_stride_lists[row_id][None, :] + + self.base_ptrs[row_id][None, :] + ) + @property def tensor_size_list(self) -> list[int]: return ( @@ -2609,15 +2618,32 @@ def __init__( kv_cache_config: "KVCacheConfig", ): super().__init__(vllm_config, role, kv_cache_config) - self.load_tasks: dict[int, dict[str, Task]] = defaultdict(dict) + self.load_tasks: dict[int, list[tuple[Task, tuple[str, ...]]]] = defaultdict( + list + ) self.dump_tasks: dict[int, list[Task]] = defaultdict(list) - self.request_data: list[tuple[str, list[bytes], np.ndarray]] = [] + self.request_data: list[tuple[str, list[bytes], list[int]]] = [] self._failure_req_ids: set[str] = set() self._submitted_load_rows: set[int] = set() + self._dump_transfer_data: tuple[list[bytes], list[int]] | None = None + prefetch_rows_config = self.launch_config.get( + "hybrid_layerwise_prefetch_rows", 2 + ) + try: + self._load_prefetch_rows = max(1, int(prefetch_rows_config)) + except (TypeError, ValueError): + logger.warning( + "Invalid hybrid_layerwise_prefetch_rows=%r; fallback to 2.", + prefetch_rows_config, + ) + self._load_prefetch_rows = 2 self.use_layerwise = True self.is_save = False self.need_load = False - logger.info("Init UCMHybridLinearAttentionLayerWiseConnector.") + logger.info( + "Init UCMHybridLinearAttentionLayerWiseConnector " + f"with prefetch_rows={self._load_prefetch_rows}." + ) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): super().register_kv_caches(kv_caches) @@ -2636,29 +2662,62 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ) for row_id, layer_names in row_to_layers.items() } + logger.info( + "Hybrid layerwise layout: " + f"rows={len(self.row_ids)}, row_ids={_short_list(self.row_ids)}, " + f"row_save_layers={len(self.row_save_layer)}" + ) + + def _rank_scoped_ucm_block_ids(self, ucm_block_ids: list[bytes]) -> list[bytes]: + if self.tp_rank % self.tp_size == 0 or self.is_mla: + return ucm_block_ids + return [self.request_hasher(ucm_block_id) for ucm_block_id in ucm_block_ids] + + def _mark_load_failed( + self, + metadata: "UCMConnectorMetadata", + request_ids: tuple[str, ...], + ) -> None: + for request_id in request_ids: + request_meta = metadata.request_meta.get(request_id) + if request_meta is not None: + self._invalid_block_ids.update(request_meta.load_block_ids[1]) + self._failure_req_ids.add(request_id) def _submit_request_load_tasks_for_row( self, row_id: int, metadata: "UCMConnectorMetadata", ) -> None: - for request_id, ucm_block_ids, total_ptrs in self.request_data: + total_ucm_block_ids: list[bytes] = [] + row_ptrs_parts: list[np.ndarray] = [] + request_ids: list[str] = [] + for request_id, ucm_block_ids, vllm_block_ids in self.request_data: if request_id in self._failure_req_ids: continue - try: - shard_indexs = [row_id] * len(ucm_block_ids) - row_ptrs = total_ptrs[row_id] - task = self.store.load_data(ucm_block_ids, shard_indexs, row_ptrs) - self.load_tasks[row_id][request_id] = task - except Exception as e: - logger.error( - f"request {request_id} submit row {row_id} load task error. " - f"{type(e).__name__}: {e}" - ) - self._invalid_block_ids.update( - metadata.request_meta[request_id].load_block_ids[1] - ) - self._failure_req_ids.add(request_id) + total_ucm_block_ids.extend(ucm_block_ids) + row_ptrs_parts.append( + self.kv_cache_layout.extract_block_addrs_for_row(vllm_block_ids, row_id) + ) + request_ids.append(request_id) + if not total_ucm_block_ids: + self._submitted_load_rows.add(row_id) + return + try: + row_ptrs = ( + np.ascontiguousarray(row_ptrs_parts[0]) + if len(row_ptrs_parts) == 1 + else np.ascontiguousarray(np.concatenate(row_ptrs_parts, axis=0)) + ) + shard_indexs = [row_id] * len(total_ucm_block_ids) + task = self.store.load_data(total_ucm_block_ids, shard_indexs, row_ptrs) + self.load_tasks[row_id].append((task, tuple(request_ids))) + except Exception as e: + logger.error( + f"submit hybrid layerwise row {row_id} load task error. " + f"{type(e).__name__}: {e}" + ) + self._mark_load_failed(metadata, tuple(request_ids)) self._submitted_load_rows.add(row_id) def _submit_request_load_tasks_for_row_once( @@ -2670,6 +2729,15 @@ def _submit_request_load_tasks_for_row_once( return self._submit_request_load_tasks_for_row(row_id, metadata) + def _submit_prefetch_rows( + self, + start_idx: int, + metadata: "UCMConnectorMetadata", + ) -> None: + end_idx = min(start_idx + self._load_prefetch_rows, len(self.row_ids)) + for idx in range(start_idx, end_idx): + self._submit_request_load_tasks_for_row_once(self.row_ids[idx], metadata) + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: metadata = self._get_connector_metadata() assert isinstance(metadata, UCMConnectorMetadata) @@ -2677,6 +2745,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: self.request_data.clear() self._failure_req_ids.clear() self._submitted_load_rows.clear() + self._dump_transfer_data = None self.need_load = False for request_id, request in metadata.request_meta.items(): @@ -2693,16 +2762,11 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: if len(ucm_block_ids) == 0: continue self.need_load = True - if self.tp_rank % self.tp_size != 0 and not self.is_mla: - for i, ucm_block_id in enumerate(ucm_block_ids): - ucm_block_ids[i] = self.request_hasher(ucm_block_id) - total_ptrs = self.kv_cache_layout.extract_block_addrs( - vllm_block_ids, layer_first=True - ) - self.request_data.append((request_id, ucm_block_ids, total_ptrs)) + ucm_block_ids = self._rank_scoped_ucm_block_ids(ucm_block_ids) + self.request_data.append((request_id, ucm_block_ids, vllm_block_ids)) if self.need_load and self.row_ids: - self._submit_request_load_tasks_for_row_once(self.row_ids[0], metadata) + self._submit_prefetch_rows(0, metadata) def wait_for_layer_load(self, layer_name: str) -> None: if not self._connector_metadata or not self.need_load: @@ -2714,25 +2778,25 @@ def wait_for_layer_load(self, layer_name: str) -> None: return self._submit_request_load_tasks_for_row_once(row_id, metadata) - row_tasks = self.load_tasks.pop(row_id, {}) - for request_id, task in row_tasks.items(): + row_tasks = self.load_tasks.pop(row_id, []) + for task, request_ids in row_tasks: try: self.store.wait(task) except Exception as e: logger.error( - f"request {request_id} wait {layer_name} load failed. " + f"requests {list(request_ids)} wait {layer_name} load failed. " f"{type(e).__name__}: {e}" ) - self._invalid_block_ids.update( - metadata.request_meta[request_id].load_block_ids[1] - ) - self._failure_req_ids.add(request_id) + self._mark_load_failed(metadata, request_ids) try: - next_row_id = self.row_ids[self.row_ids.index(row_id) + 1] + next_row_idx = self.row_ids.index(row_id) + self._load_prefetch_rows except (ValueError, IndexError): return - self._submit_request_load_tasks_for_row_once(next_row_id, metadata) + if next_row_idx < len(self.row_ids): + self._submit_request_load_tasks_for_row_once( + self.row_ids[next_row_idx], metadata + ) def save_kv_layer( self, @@ -2754,38 +2818,21 @@ def save_kv_layer( metadata = self._get_connector_metadata() assert isinstance(metadata, UCMConnectorMetadata) - total_ucm_block_ids: list[bytes] = [] - total_vllm_block_ids: list[int] = [] - for _, request in metadata.request_meta.items(): - if len(request.dump_block_ids[0]) == 0: - continue - - ucm_block_ids, vllm_block_ids = request.dump_block_ids - if self._skip_null_vllm_blocks: - ucm_block_ids, vllm_block_ids = _drop_null_vllm_blocks( - ucm_block_ids, - vllm_block_ids, - f"UCM hybrid layerwise dump row {row_id}", - ) - if len(ucm_block_ids) == 0: - continue - if self.tp_rank % self.tp_size != 0: - for i, ucm_block_id in enumerate(ucm_block_ids): - ucm_block_ids[i] = self.request_hasher(ucm_block_id) - total_ucm_block_ids.extend(ucm_block_ids) - total_vllm_block_ids.extend(vllm_block_ids) + if self._dump_transfer_data is None: + self._dump_transfer_data = self._build_dump_transfer_data(metadata, row_id) + total_ucm_block_ids, total_vllm_block_ids = self._dump_transfer_data if not total_ucm_block_ids: return self.is_save = True self._wait_dump_tasks_for_row(row_id) - total_ptrs = self.kv_cache_layout.extract_block_addrs( - total_vllm_block_ids, layer_first=True + row_ptrs = self.kv_cache_layout.extract_block_addrs_for_row( + total_vllm_block_ids, row_id ) shard_indexs = [row_id] * len(total_ucm_block_ids) try: - row_ptrs = np.ascontiguousarray(total_ptrs[row_id]) + row_ptrs = np.ascontiguousarray(row_ptrs) event_handle = self._get_dump_event_handle() task = self.store.dump_data( total_ucm_block_ids, shard_indexs, row_ptrs, event_handle @@ -2797,6 +2844,31 @@ def save_kv_layer( f"{type(e).__name__}: {e}" ) + def _build_dump_transfer_data( + self, + metadata: "UCMConnectorMetadata", + row_id: int, + ) -> tuple[list[bytes], list[int]]: + total_ucm_block_ids: list[bytes] = [] + total_vllm_block_ids: list[int] = [] + for _, request in metadata.request_meta.items(): + if len(request.dump_block_ids[0]) == 0: + continue + + ucm_block_ids, vllm_block_ids = request.dump_block_ids + if self._skip_null_vllm_blocks: + ucm_block_ids, vllm_block_ids = _drop_null_vllm_blocks( + ucm_block_ids, + vllm_block_ids, + f"UCM hybrid layerwise dump row {row_id}", + ) + if len(ucm_block_ids) == 0: + continue + ucm_block_ids = self._rank_scoped_ucm_block_ids(ucm_block_ids) + total_ucm_block_ids.extend(ucm_block_ids) + total_vllm_block_ids.extend(vllm_block_ids) + return total_ucm_block_ids, total_vllm_block_ids + def _wait_dump_tasks_for_row(self, row_id: int) -> None: row_tasks = self.dump_tasks.pop(row_id, []) for task in row_tasks: @@ -2811,6 +2883,7 @@ def wait_for_save(self) -> None: except Exception as e: logger.error(f"wait for dump kv cache failed. {type(e).__name__}: {e}") self.dump_tasks.clear() + self._dump_transfer_data = None self.is_save = False if self.enable_event_sync: self.device.destroy_event_handles() From 33ce1efcc057ba6d6d0cc3345b413787cc35dd46 Mon Sep 17 00:00:00 2001 From: MaxWang Date: Mon, 25 May 2026 09:45:45 +0800 Subject: [PATCH 4/4] mtp fix --- ucm/integration/vllm/ucm_connector.py | 45 ++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 0ee126318..dc007d715 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -1903,6 +1903,41 @@ class HybridLinearAttentionLayout(HMAKVCacheLayout): with a full-page stride. """ + def _finalize_layout_arrays( + self, + base_ptrs: list[list[int]], + buffer_size_rows: list[list[int]], + tensor_size_lists: list[list[int]], + block_stride_lists: list[list[int]], + ) -> None: + # MTP can add attention-only raw tensors next to hybrid attention+Mamba + # tensors. Those rows naturally have a different number of physical + # slices, so keep the UCM schema flattened instead of forcing a + # rectangular layer-by-slice matrix. + self.base_ptrs = np.asarray( + [ptr for row in base_ptrs for ptr in row], dtype=np.uint64 + ) + self.buffer_sizes = np.asarray( + [size for row in buffer_size_rows for size in row], dtype=np.uint64 + ) + self.tensor_size_lists = np.asarray( + [size for row in tensor_size_lists for size in row], dtype=np.uint64 + ) + self.block_stride_lists = np.asarray( + [stride for row in block_stride_lists for stride in row], dtype=np.uint64 + ) + + def extract_block_addrs( + self, vllm_block_ids: List[int], layer_first: bool = False + ) -> np.ndarray: + if layer_first: + raise ValueError("layer_first is not supported for flattened hybrid layout") + vllm_block_ids_np = np.asarray(vllm_block_ids, dtype=np.uint64) + return ( + vllm_block_ids_np[:, None] * self.block_stride_lists[None, :] + + self.base_ptrs[None, :] + ) + def _collect_shared_tensor_info( self, raw_tensor, @@ -2064,10 +2099,12 @@ def _build_layout(self, kvcaches): for layer_name in raw_tensor.shared_by: self.layer_name_to_row[layer_name] = row_id - self.base_ptrs = np.asarray(base_ptrs, dtype=np.uint64) - self.buffer_sizes = np.asarray(buffer_size_rows, dtype=np.uint64) - self.tensor_size_lists = np.asarray(tensor_size_lists, dtype=np.uint64) - self.block_stride_lists = np.asarray(block_stride_lists, dtype=np.uint64) + self._finalize_layout_arrays( + base_ptrs, + buffer_size_rows, + tensor_size_lists, + block_stride_lists, + ) class UCMHMAConnector(UCMDirectConnector, SupportsHMA):