diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 4607eb990..a16f1dc3a 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -1250,72 +1250,127 @@ def get_num_new_matched_tokens( return expect_hit_block_num * self.block_size, False -class UCMLiteConnector(UCMDirectConnector): +class UCMLiteConnector(KVConnectorBase_V1): def __init__( self, vllm_config, role, kv_cache_config: Optional["KVCacheConfig"] = None, ): - ucm_config = Config(vllm_config.kv_transfer_config) - launch_config = ucm_config.get_config() - enable_record_traces = launch_config.get("enable_record_traces", False) - persist_token_threshold = launch_config.get("persist_token_threshold", 0) - vllm_config.kv_transfer_config.kv_connector_extra_config = { - "ucm_connectors": [ - { - "ucm_connector_name": "UcmPipelineStore", - "ucm_connector_config": { - "store_pipeline": "Fake", - "share_buffer_enable": True, - "buffer_number": 244032232, - }, - } - ], - "enable_record_traces": enable_record_traces, - "persist_token_threshold": persist_token_threshold, - "use_lite": True, - } - super().__init__(vllm_config, role, kv_cache_config) + self.block_size = vllm_config.cache_config.block_size + self.hash_block_size = self.block_size + self.requests_meta: dict[str, RequestMeta] = {} self.total_block_nums = 0 - self.total_hit_block_nums = 0 + + self.request_hasher = RequestHasher(vllm_config, 0) + self._seed = self.request_hasher("UCM_HASH_SEED") + + super().__init__(vllm_config, role, kv_cache_config) + logger.info("Init UCMLiteConnector.") def get_num_new_matched_tokens(self, request, num_computed_tokens): - super().get_num_new_matched_tokens(request, num_computed_tokens) - - external_hit_blocks = 0 req_blocks_num = len(request.all_token_ids) // self.hash_block_size if req_blocks_num < 1: return 0, False - self.total_block_nums += req_blocks_num - if request.request_id in self.requests_meta: - request_meta = self.requests_meta[request.request_id] - external_hit_blocks = ( - request_meta.total_hit_block_num - request_meta.hbm_hit_block_num + if request.request_id not in self.requests_meta: + hash_start = time.perf_counter() + ucm_block_ids = self.generate_hash( + self.hash_block_size, request.all_token_ids, self._seed + ) + hash_end = time.perf_counter() + hash_time_ms = (hash_end - hash_start) * 1000.0 + + print_start = time.perf_counter() + hex_ucm_block_ids = [b.hex() for b in ucm_block_ids] + logger.info( + f"timestamp: {time.perf_counter()}, " + f"request_id: {request.request_id}, " + f"input_length: {request.num_tokens}, " + f"output_length: {request.max_tokens}, " + f"ucm_block_ids: {hex_ucm_block_ids}" + ) + print_time_ms = (time.perf_counter() - print_start) * 1000.0 + logger.info( + f"request_id: {request.request_id}, " + f"hash_time_ms: {hash_time_ms:.3f}, " + f"print_time_ms: {print_time_ms:.3f}" ) - need_dump_blks = request_meta.ucm_block_ids[ - request_meta.total_hit_block_num : - ] - shard_indexs = [0] * len(need_dump_blks) - total_ptrs = [[0]] * len(need_dump_blks) - try: - task = self.store.dump_data(need_dump_blks, shard_indexs, total_ptrs) - self.store.wait(task) - except Exception as e: - logger.error( - f"request {request.request_id} wait dump task error. {type(e).__name__}: {e}" - ) - self.requests_meta[request.request_id] = RequestMeta() - self.total_hit_block_nums += external_hit_blocks + # store minimal RequestMeta for scheduler bookkeeping + self.requests_meta[request.request_id] = RequestMeta( + ucm_block_ids=ucm_block_ids, + hbm_hit_block_num=0, + total_hit_block_num=0, + num_token_ids=len(request.all_token_ids), + token_processed=0, + ) + + self.total_block_nums += req_blocks_num - logger.info( - f"req external hit rate: {(external_hit_blocks / req_blocks_num):.2f}, " - f"total external hit rate: {(self.total_hit_block_nums / self.total_block_nums):.2f}" - ) return 0, False + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + pass + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + for request_id in scheduler_output.finished_req_ids: + self.requests_meta.pop(request_id, None) + return UCMConnectorMetadata() + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + self.requests_meta.pop(request.request_id, None) + return False, None + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: + pass + + def wait_for_save(self): + pass + + def generate_hash( + self, + block_size: int, + token_ids: List[int], + parent_block_hash_value: bytes, + ) -> list[bytes]: + ret = [] + for start in range(0, len(token_ids), block_size): + end = start + block_size + block_token_ids = token_ids[start:end] + # Do not hash the block if it is not full. + if len(block_token_ids) < block_size: + break + + block_token_ids_tuple = tuple(block_token_ids) + hash_value = self.request_hasher( + (parent_block_hash_value, block_token_ids_tuple) + ) + parent_block_hash_value = hash_value + ret.append(hash_value) + + return ret + def layer_name_to_kv_cache_spec( kv_cache_config: KVCacheConfig,