Skip to content
Open

lite #974

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 103 additions & 48 deletions ucm/integration/vllm/ucm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,72 +1250,127 @@ def get_num_new_matched_tokens(
return expect_hit_block_num * self.block_size, False


class UCMLiteConnector(UCMDirectConnector):
class UCMLiteConnector(KVConnectorBase_V1):
def __init__(
self,
vllm_config,
role,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
ucm_config = Config(vllm_config.kv_transfer_config)
launch_config = ucm_config.get_config()
enable_record_traces = launch_config.get("enable_record_traces", False)
persist_token_threshold = launch_config.get("persist_token_threshold", 0)
vllm_config.kv_transfer_config.kv_connector_extra_config = {
"ucm_connectors": [
{
"ucm_connector_name": "UcmPipelineStore",
"ucm_connector_config": {
"store_pipeline": "Fake",
"share_buffer_enable": True,
"buffer_number": 244032232,
},
}
],
"enable_record_traces": enable_record_traces,
"persist_token_threshold": persist_token_threshold,
"use_lite": True,
}
super().__init__(vllm_config, role, kv_cache_config)
self.block_size = vllm_config.cache_config.block_size
self.hash_block_size = self.block_size
self.requests_meta: dict[str, RequestMeta] = {}
self.total_block_nums = 0
self.total_hit_block_nums = 0

self.request_hasher = RequestHasher(vllm_config, 0)
self._seed = self.request_hasher("UCM_HASH_SEED")

super().__init__(vllm_config, role, kv_cache_config)

logger.info("Init UCMLiteConnector.")

def get_num_new_matched_tokens(self, request, num_computed_tokens):
super().get_num_new_matched_tokens(request, num_computed_tokens)

external_hit_blocks = 0
req_blocks_num = len(request.all_token_ids) // self.hash_block_size
if req_blocks_num < 1:
return 0, False
self.total_block_nums += req_blocks_num
if request.request_id in self.requests_meta:
request_meta = self.requests_meta[request.request_id]
external_hit_blocks = (
request_meta.total_hit_block_num - request_meta.hbm_hit_block_num
if request.request_id not in self.requests_meta:
hash_start = time.perf_counter()
ucm_block_ids = self.generate_hash(
self.hash_block_size, request.all_token_ids, self._seed
)
hash_end = time.perf_counter()
hash_time_ms = (hash_end - hash_start) * 1000.0

print_start = time.perf_counter()
hex_ucm_block_ids = [b.hex() for b in ucm_block_ids]
logger.info(
f"timestamp: {time.perf_counter()}, "
f"request_id: {request.request_id}, "
f"input_length: {request.num_tokens}, "
f"output_length: {request.max_tokens}, "
f"ucm_block_ids: {hex_ucm_block_ids}"
)
print_time_ms = (time.perf_counter() - print_start) * 1000.0
logger.info(
f"request_id: {request.request_id}, "
f"hash_time_ms: {hash_time_ms:.3f}, "
f"print_time_ms: {print_time_ms:.3f}"
)
need_dump_blks = request_meta.ucm_block_ids[
request_meta.total_hit_block_num :
]
shard_indexs = [0] * len(need_dump_blks)
total_ptrs = [[0]] * len(need_dump_blks)
try:
task = self.store.dump_data(need_dump_blks, shard_indexs, total_ptrs)
self.store.wait(task)
except Exception as e:
logger.error(
f"request {request.request_id} wait dump task error. {type(e).__name__}: {e}"
)
self.requests_meta[request.request_id] = RequestMeta()

self.total_hit_block_nums += external_hit_blocks
# store minimal RequestMeta for scheduler bookkeeping
self.requests_meta[request.request_id] = RequestMeta(
ucm_block_ids=ucm_block_ids,
hbm_hit_block_num=0,
total_hit_block_num=0,
num_token_ids=len(request.all_token_ids),
token_processed=0,
)

self.total_block_nums += req_blocks_num

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Warning: self.total_block_nums is only incremented when a new request is added to requests_meta. If a request already exists (e.g., after a chunked prefill), this counter won't be updated. This could lead to inconsistent statistics. Consider either incrementing for all requests or clarifying the intended behavior.


logger.info(
f"req external hit rate: {(external_hit_blocks / req_blocks_num):.2f}, "
f"total external hit rate: {(self.total_hit_block_nums / self.total_block_nums):.2f}"
)
return 0, False
Comment thread
qyh111 marked this conversation as resolved.

def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Suggestion: The stub methods (update_state_after_alloc, start_load_kv, wait_for_layer_load, save_kv_layer, wait_for_save) should have docstrings explaining that they are intentionally empty for the lite connector, or raise NotImplementedError if they are not meant to be called.

):
pass
Comment thread
qyh111 marked this conversation as resolved.

def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
for request_id in scheduler_output.finished_req_ids:
self.requests_meta.pop(request_id, None)
return UCMConnectorMetadata()

def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
self.requests_meta.pop(request.request_id, None)
return False, None

def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
pass

def wait_for_layer_load(self, layer_name: str) -> None:
pass

def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
pass

def wait_for_save(self):
pass

def generate_hash(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Warning: The generate_hash method uses self.request_hasher which could raise exceptions. There's no error handling here. If the hasher fails, the exception will propagate up and could crash the request processing. Consider adding a try-except block or documenting the expected behavior.

self,
block_size: int,
token_ids: List[int],
parent_block_hash_value: bytes,
) -> list[bytes]:
ret = []
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = token_ids[start:end]
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break

block_token_ids_tuple = tuple(block_token_ids)
hash_value = self.request_hasher(
(parent_block_hash_value, block_token_ids_tuple)
)
parent_block_hash_value = hash_value
ret.append(hash_value)

return ret


def layer_name_to_kv_cache_spec(
kv_cache_config: KVCacheConfig,
Expand Down
Loading