-
Notifications
You must be signed in to change notification settings - Fork 86
lite #974
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
lite #974
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1250,72 +1250,127 @@ def get_num_new_matched_tokens( | |
| return expect_hit_block_num * self.block_size, False | ||
|
|
||
|
|
||
| class UCMLiteConnector(UCMDirectConnector): | ||
| class UCMLiteConnector(KVConnectorBase_V1): | ||
| def __init__( | ||
| self, | ||
| vllm_config, | ||
| role, | ||
| kv_cache_config: Optional["KVCacheConfig"] = None, | ||
| ): | ||
| ucm_config = Config(vllm_config.kv_transfer_config) | ||
| launch_config = ucm_config.get_config() | ||
| enable_record_traces = launch_config.get("enable_record_traces", False) | ||
| persist_token_threshold = launch_config.get("persist_token_threshold", 0) | ||
| vllm_config.kv_transfer_config.kv_connector_extra_config = { | ||
| "ucm_connectors": [ | ||
| { | ||
| "ucm_connector_name": "UcmPipelineStore", | ||
| "ucm_connector_config": { | ||
| "store_pipeline": "Fake", | ||
| "share_buffer_enable": True, | ||
| "buffer_number": 244032232, | ||
| }, | ||
| } | ||
| ], | ||
| "enable_record_traces": enable_record_traces, | ||
| "persist_token_threshold": persist_token_threshold, | ||
| "use_lite": True, | ||
| } | ||
| super().__init__(vllm_config, role, kv_cache_config) | ||
| self.block_size = vllm_config.cache_config.block_size | ||
| self.hash_block_size = self.block_size | ||
| self.requests_meta: dict[str, RequestMeta] = {} | ||
| self.total_block_nums = 0 | ||
| self.total_hit_block_nums = 0 | ||
|
|
||
| self.request_hasher = RequestHasher(vllm_config, 0) | ||
| self._seed = self.request_hasher("UCM_HASH_SEED") | ||
|
|
||
| super().__init__(vllm_config, role, kv_cache_config) | ||
|
|
||
| logger.info("Init UCMLiteConnector.") | ||
|
|
||
| def get_num_new_matched_tokens(self, request, num_computed_tokens): | ||
| super().get_num_new_matched_tokens(request, num_computed_tokens) | ||
|
|
||
| external_hit_blocks = 0 | ||
| req_blocks_num = len(request.all_token_ids) // self.hash_block_size | ||
| if req_blocks_num < 1: | ||
| return 0, False | ||
| self.total_block_nums += req_blocks_num | ||
| if request.request_id in self.requests_meta: | ||
| request_meta = self.requests_meta[request.request_id] | ||
| external_hit_blocks = ( | ||
| request_meta.total_hit_block_num - request_meta.hbm_hit_block_num | ||
| if request.request_id not in self.requests_meta: | ||
| hash_start = time.perf_counter() | ||
| ucm_block_ids = self.generate_hash( | ||
| self.hash_block_size, request.all_token_ids, self._seed | ||
| ) | ||
| hash_end = time.perf_counter() | ||
| hash_time_ms = (hash_end - hash_start) * 1000.0 | ||
|
|
||
| print_start = time.perf_counter() | ||
| hex_ucm_block_ids = [b.hex() for b in ucm_block_ids] | ||
| logger.info( | ||
| f"timestamp: {time.perf_counter()}, " | ||
| f"request_id: {request.request_id}, " | ||
| f"input_length: {request.num_tokens}, " | ||
| f"output_length: {request.max_tokens}, " | ||
| f"ucm_block_ids: {hex_ucm_block_ids}" | ||
| ) | ||
| print_time_ms = (time.perf_counter() - print_start) * 1000.0 | ||
| logger.info( | ||
| f"request_id: {request.request_id}, " | ||
| f"hash_time_ms: {hash_time_ms:.3f}, " | ||
| f"print_time_ms: {print_time_ms:.3f}" | ||
| ) | ||
| need_dump_blks = request_meta.ucm_block_ids[ | ||
| request_meta.total_hit_block_num : | ||
| ] | ||
| shard_indexs = [0] * len(need_dump_blks) | ||
| total_ptrs = [[0]] * len(need_dump_blks) | ||
| try: | ||
| task = self.store.dump_data(need_dump_blks, shard_indexs, total_ptrs) | ||
| self.store.wait(task) | ||
| except Exception as e: | ||
| logger.error( | ||
| f"request {request.request_id} wait dump task error. {type(e).__name__}: {e}" | ||
| ) | ||
| self.requests_meta[request.request_id] = RequestMeta() | ||
|
|
||
| self.total_hit_block_nums += external_hit_blocks | ||
| # store minimal RequestMeta for scheduler bookkeeping | ||
| self.requests_meta[request.request_id] = RequestMeta( | ||
| ucm_block_ids=ucm_block_ids, | ||
| hbm_hit_block_num=0, | ||
| total_hit_block_num=0, | ||
| num_token_ids=len(request.all_token_ids), | ||
| token_processed=0, | ||
| ) | ||
|
|
||
| self.total_block_nums += req_blocks_num | ||
|
|
||
| logger.info( | ||
| f"req external hit rate: {(external_hit_blocks / req_blocks_num):.2f}, " | ||
| f"total external hit rate: {(self.total_hit_block_nums / self.total_block_nums):.2f}" | ||
| ) | ||
| return 0, False | ||
|
qyh111 marked this conversation as resolved.
|
||
|
|
||
| def update_state_after_alloc( | ||
| self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 💡 Suggestion: The stub methods ( |
||
| ): | ||
| pass | ||
|
qyh111 marked this conversation as resolved.
|
||
|
|
||
| def build_connector_meta( | ||
| self, scheduler_output: SchedulerOutput | ||
| ) -> KVConnectorMetadata: | ||
| for request_id in scheduler_output.finished_req_ids: | ||
| self.requests_meta.pop(request_id, None) | ||
| return UCMConnectorMetadata() | ||
|
|
||
| def request_finished( | ||
| self, | ||
| request: "Request", | ||
| block_ids: list[int], | ||
| ) -> tuple[bool, dict[str, Any] | None]: | ||
| self.requests_meta.pop(request.request_id, None) | ||
| return False, None | ||
|
|
||
| def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: | ||
| pass | ||
|
|
||
| def wait_for_layer_load(self, layer_name: str) -> None: | ||
| pass | ||
|
|
||
| def save_kv_layer( | ||
| self, | ||
| layer_name: str, | ||
| kv_layer: torch.Tensor, | ||
| attn_metadata: "AttentionMetadata", | ||
| **kwargs: Any, | ||
| ) -> None: | ||
| pass | ||
|
|
||
| def wait_for_save(self): | ||
| pass | ||
|
|
||
| def generate_hash( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| self, | ||
| block_size: int, | ||
| token_ids: List[int], | ||
| parent_block_hash_value: bytes, | ||
| ) -> list[bytes]: | ||
| ret = [] | ||
| for start in range(0, len(token_ids), block_size): | ||
| end = start + block_size | ||
| block_token_ids = token_ids[start:end] | ||
| # Do not hash the block if it is not full. | ||
| if len(block_token_ids) < block_size: | ||
| break | ||
|
|
||
| block_token_ids_tuple = tuple(block_token_ids) | ||
| hash_value = self.request_hasher( | ||
| (parent_block_hash_value, block_token_ids_tuple) | ||
| ) | ||
| parent_block_hash_value = hash_value | ||
| ret.append(hash_value) | ||
|
|
||
| return ret | ||
|
|
||
|
|
||
| def layer_name_to_kv_cache_spec( | ||
| kv_cache_config: KVCacheConfig, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`self.total_block_nums` is only incremented when a new request is added to `requests_meta`. If a request already exists (e.g., after a chunked prefill), this counter won't be updated. This could lead to inconsistent statistics. Consider either incrementing for all requests or clarifying the intended behavior.