Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flexkv/cache/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ class CacheStrategy:

# Strategy built from CacheStrategy's own field defaults (no overrides).
# Used as the fallback `temp_cache_strategy` when a caller does not ask for
# a restricted lookup.
DEFAULT_CACHE_STRATEGY = CacheStrategy()

# Preset selected when a match request is issued with `cpu_only=True`.
# It sets the SSD, remote and GDS ignore flags while leaving `ignore_gpu`
# False.
# NOTE(review): presumably this limits matching to the CPU tier (with GPU
# still considered) -- confirm against how CacheStrategy's flags are consumed.
CPUONLY_CACHE_STRATEGY = CacheStrategy(ignore_gpu=False, ignore_ssd=True, ignore_remote=True, ignore_gds=True)

class GlobalCacheEngine:
def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_meta: RedisMeta = None,
event_collector: Optional[KVEventCollector] = None):
Expand Down
3 changes: 3 additions & 0 deletions flexkv/kvmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def get_match(self,
token_mask: Optional[Union[torch.Tensor, np.ndarray]] = None,
layer_granularity: int = -1,
dp_id: int = 0,
cpu_only: bool = False,
namespace: Optional[List[str]] = None,
) -> Tuple[int, np.ndarray]:
if isinstance(token_ids, torch.Tensor):
Expand All @@ -185,12 +186,14 @@ def get_match(self,
task_id, mask = self.dp_client.get_match(token_ids,
token_mask,
layer_granularity,
cpu_only=cpu_only,
namespace=namespace)
else:
task_id, mask = self.kv_task_engine.get_match(token_ids,
token_mask,
layer_granularity,
dp_id,
cpu_only=cpu_only,
namespace=namespace)
return task_id, mask

Expand Down
16 changes: 15 additions & 1 deletion flexkv/kvtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from flexkv.common.block import hash_token
from flexkv.common.transfer import TransferOpGraph, merge_to_batch_graph, get_nvtx_default_color, CompletedOp
from flexkv.common.tracer import FlexKVTracer
from flexkv.cache.cache_engine import GlobalCacheEngine, DEFAULT_CACHE_STRATEGY
from flexkv.cache.cache_engine import (
GlobalCacheEngine,
CacheStrategy,
DEFAULT_CACHE_STRATEGY,
CPUONLY_CACHE_STRATEGY,
)
from flexkv.transfer_manager import TransferManagerHandle, TransferManagerOnRemote
from flexkv.common.request import KVResponseStatus, KVResponse
from flexkv.transfer_manager import (
Expand Down Expand Up @@ -226,6 +231,7 @@ def create_get_task(self,
layer_granularity: int = -1,
dp_id: int = 0,
is_fake_slot_mapping: bool = False,
temp_cache_strategy: CacheStrategy = DEFAULT_CACHE_STRATEGY,
namespace: Optional[List[str]] = None,
) -> None:
if task_id in self.tasks:
Expand All @@ -238,6 +244,7 @@ def create_get_task(self,
layer_num=self.model_config.num_layers,
layer_granularity=layer_granularity,
dp_id=dp_id,
temp_cache_strategy=temp_cache_strategy,
namespace=namespace)
self.tasks[task_id] = KVTask(
task_id=task_id,
Expand Down Expand Up @@ -696,6 +703,7 @@ def get_match(self,
token_mask: Optional[np.ndarray] = None,
layer_granularity: int = -1,
dp_id: int = 0,
cpu_only: bool = False,
task_id: int = -1,
namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]:
nvtx.push_range(f"get match: task_id={task_id}", color=get_nvtx_default_color())
Expand All @@ -709,6 +717,7 @@ def get_match(self,
token_mask=token_mask,
layer_granularity=layer_granularity,
dp_id=dp_id,
cpu_only=cpu_only,
task_id=task_id,
namespace=namespace)
# trace get match request
Expand All @@ -731,6 +740,7 @@ def _get_match_impl(self,
token_mask: Optional[np.ndarray] = None,
layer_granularity: int = -1,
dp_id: int = 0,
cpu_only: bool = False,
task_id: int = -1,
namespace: Optional[List[str]] = None) -> Tuple[int, np.ndarray]:
if token_mask is None:
Expand All @@ -739,6 +749,9 @@ def _get_match_impl(self,
layer_granularity = self.model_config.num_layers
if task_id == -1:
task_id = self._gen_task_id()
temp_cache_strategy = DEFAULT_CACHE_STRATEGY
if cpu_only:
temp_cache_strategy = CPUONLY_CACHE_STRATEGY
nvtx.push_range(f"get match: task_id={task_id}", color=get_nvtx_default_color())
self.create_get_task(task_id,
token_ids,
Expand All @@ -747,6 +760,7 @@ def _get_match_impl(self,
layer_granularity,
dp_id,
is_fake_slot_mapping=is_fake_slot_mapping,
temp_cache_strategy=temp_cache_strategy,
namespace=namespace)
self._process_empty_graph(task_id)
nvtx.pop_range()
Expand Down
2 changes: 2 additions & 0 deletions flexkv/server/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ def get_match(
token_ids: np.ndarray,
token_mask: Optional[np.ndarray],
layer_granularity: int,
cpu_only: bool = False,
namespace: Optional[List[str]] = None,
) -> Optional[Tuple[int, np.ndarray]]:
req = GetMatchRequest(self.dp_client_id,
token_ids,
token_mask if token_mask is not None else None,
layer_granularity,
cpu_only,
self._get_task_id(),
namespace)
self.send_to_server.send_pyobj(req)
Expand Down
1 change: 1 addition & 0 deletions flexkv/server/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class GetMatchRequest:
token_ids: np.ndarray
token_mask: Optional[np.ndarray]
layer_granularity: int
cpu_only: bool = False
task_id: int = -1
namespace: Optional[List[str]] = None

Expand Down
1 change: 1 addition & 0 deletions flexkv/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def _handle_get_match_request(self, req: GetMatchRequest) -> None:
token_mask=req.token_mask,
layer_granularity=req.layer_granularity,
dp_id=req.dp_client_id,
cpu_only=req.cpu_only,
task_id=req.task_id,
namespace=req.namespace,
)
Expand Down
Loading