diff --git a/docker/Dockerfile b/docker/Dockerfile index 8f73a603cc..e766107ae7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -3,7 +3,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 ARG PYTHON_VERSION=3.10 ARG MAMBA_VERSION=24.7.1-0 -ARG VLLM_VERSION=0.11.0 +ARG VLLM_VERSION=0.16.0 ARG TARGETPLATFORM ARG ENABLE_DEEPEP=1 ARG ENABLE_NIXL=1 @@ -20,14 +20,15 @@ RUN chmod 777 -R /tmp && \ curl \ g++ \ make \ - git && \ + git \ + wget && \ rm -rf /var/lib/apt/lists/* RUN case ${TARGETPLATFORM} in \ "linux/arm64") MAMBA_ARCH=aarch64 ;; \ *) MAMBA_ARCH=x86_64 ;; \ esac && \ - curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \ + wget -O ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \ bash ~/mambaforge.sh -b -p /opt/conda && \ rm ~/mambaforge.sh @@ -44,7 +45,6 @@ COPY ./requirements.txt /lightllm/requirements.txt RUN pip install -U pip RUN pip install -r /lightllm/requirements.txt --no-cache-dir RUN pip install --no-cache-dir vllm==${VLLM_VERSION} -RUN pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/lightllm_kernel-0.1.0-cp310-cp310-linux_x86_64.whl RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* diff --git a/docker/scripts/build.sh b/docker/scripts/build.sh index 1699b39dd7..355d6c65b3 100644 --- a/docker/scripts/build.sh +++ b/docker/scripts/build.sh @@ -100,5 +100,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \ --build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \ --build-arg ENABLE_NIXL="${ENABLE_NIXL}" \ --build-arg ENABLE_CACHE="${ENABLE_CACHE}" \ - -t "${IMAGE_PREFIX}:${IMAGE_TAG}" . + --progress=plain \ + -t "${IMAGE_PREFIX}:${IMAGE_TAG}" . 
diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 80df545498..0eea52cc89 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -10,9 +10,14 @@ from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend +# NSA backend +from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend + from .create_utils import ( get_prefill_att_backend_class, get_decode_att_backend_class, get_mla_prefill_att_backend_class, get_mla_decode_att_backend_class, + get_nsa_prefill_att_backend_class, + get_nsa_decode_att_backend_class, ) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 859d97ca84..1286a46ec2 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -65,6 +65,11 @@ class AttControl: mla_prefill_dict: Dict = None mla_decode: bool = False mla_decode_dict: Dict = None + # nsa (native sparse attention) 专用传参项 + nsa_prefill: bool = False + nsa_prefill_dict: Dict = None + nsa_decode: bool = False + nsa_decode_dict: Dict = None @dataclass diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 19252cf13a..1fcde2a5ca 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -1,10 +1,8 @@ """Attention backend selection utilities.""" - -import os -import torch from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.utils.backend_validator import validate +from typing import Dict from .base_att import BaseAttBackend from .triton.fp import TritonAttBackend from .triton.int4kv import Int4kvTritonAttBackend @@ -16,6 +14,7 @@ from .flashinfer.fp8 import Fp8FlashInferAttBackend from .flashinfer.fp import 
FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend +from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend logger = init_logger(__name__) @@ -46,16 +45,25 @@ }, } +nsa_data_type_to_backend = { + "None": { + "flashmla_sparse": NsaFlashMlaSparseAttBackend, + # Future backends: "fa3", "tilelang", "aiter" + }, +} + def _auto_select_backend( - llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", "triton"] + llm_dtype: str, + kv_type_to_backend: Dict[str, Dict[str, BaseAttBackend]], + priority_list: list = ["fa3", "flashinfer", "triton"], ) -> type: """Auto-select the best available backend with validation. Priority: FA3 > FlashInfer > Triton Each backend is validated in a subprocess with ground truth checks. """ - backend_map = mla_data_type_to_backend if is_mla else data_type_to_backend + backend_map = kv_type_to_backend for backend_name in priority_list: if validate(backend_name): @@ -74,7 +82,7 @@ def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashi if backend_str != "auto": return data_type_to_backend[llm_dtype][backend_str] else: - return _auto_select_backend(llm_dtype, is_mla=False, priority_list=priority_list) + return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list) def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: @@ -84,7 +92,7 @@ def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashin if backend_str != "auto": return data_type_to_backend[llm_dtype][backend_str] else: - return _auto_select_backend(llm_dtype, is_mla=False, priority_list=priority_list) + return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list) def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: @@ -94,7 +102,7 @@ def 
get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "fl if backend_str != "auto": return mla_data_type_to_backend[llm_dtype][backend_str] else: - return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) + return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list) def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: @@ -104,4 +112,24 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla if backend_str != "auto": return mla_data_type_to_backend[llm_dtype][backend_str] else: - return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) + return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list) + + +def get_nsa_prefill_att_backend_class(index=0, priority_list: list = ["flashmla_sparse"]) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_prefill_att_backend[index] + if backend_str != "auto": + return nsa_data_type_to_backend[llm_dtype][backend_str] + else: + return _auto_select_backend(llm_dtype, kv_type_to_backend=nsa_data_type_to_backend, priority_list=priority_list) + + +def get_nsa_decode_att_backend_class(index=0, priority_list: list = ["flashmla_sparse"]) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_decode_att_backend[index] + if backend_str != "auto": + return nsa_data_type_to_backend[llm_dtype][backend_str] + else: + return _auto_select_backend(llm_dtype, kv_type_to_backend=nsa_data_type_to_backend, priority_list=priority_list) diff --git a/lightllm/common/basemodel/attention/nsa/__init__.py b/lightllm/common/basemodel/attention/nsa/__init__.py new file mode 100644 index 0000000000..11a1ebfdcd --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/__init__.py @@ -0,0 +1,13 @@ 
+"""NSA (Native Sparse Attention) backend implementations.""" + +from .flashmla_sparse import ( + NsaFlashMlaSparseAttBackend, + NsaFlashMlaSparsePrefillAttState, + NsaFlashMlaSparseDecodeAttState, +) + +__all__ = [ + "NsaFlashMlaSparseAttBackend", + "NsaFlashMlaSparsePrefillAttState", + "NsaFlashMlaSparseDecodeAttState", +] diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py new file mode 100644 index 0000000000..1cc2665c60 --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -0,0 +1,192 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/nsa_backend.py +# Uses sgl_kernel.flash_mla and sgl_kernel.flash_attn from the sglang kernel library. + +import dataclasses +import torch +from typing import Tuple, TYPE_CHECKING + +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_current_device_id + +if TYPE_CHECKING: + from lightllm.common.basemodel.infer_struct import InferStateInfo + + +class NsaFlashMlaSparseAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + device = get_current_device_id() + self.ragged_mem_buffers = [ + torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device) + for _ in range(2) + ] + + def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparsePrefillAttState": + return NsaFlashMlaSparsePrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparseDecodeAttState": + return NsaFlashMlaSparseDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class NsaFlashMlaSparsePrefillAttState(BasePrefillAttState): + """Prefill attention state for NSA using flash_mla_sparse_fwd.""" + + ks: torch.Tensor = None + 
ke: torch.Tensor = None + lengths: torch.Tensor = None + ragged_mem_index: torch.Tensor = None + + def init_state(self): + self.backend: NsaFlashMlaSparseAttBackend = self.backend + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num, + ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, + ) + return + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" + assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" + + return self._nsa_prefill_att(q=q, kv=k, att_control=att_control) + + def _nsa_prefill_att( + self, + q: torch.Tensor, + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + from sgl_kernel.flash_mla import flash_mla_sparse_fwd + + nsa_dict = att_control.nsa_prefill_dict + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + + mla_out, _, _ = flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=topk_indices, + sm_scale=softmax_scale, + d_v=kv_lora_rank, + ) + return mla_out + + +@dataclasses.dataclass +class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState): + + ks: torch.Tensor = None + ke: torch.Tensor = None + length: torch.Tensor 
= None + ragged_mem_index: torch.Tensor = None + nsa_cache_seqlens: torch.Tensor = None + nsa_cu_seqlens_k_new: torch.Tensor = None + + def init_state(self): + self.backend: NsaFlashMlaSparseAttBackend = self.backend + model = self.backend.model + use_cuda_graph = ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ) + + if use_cuda_graph: + self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index] + else: + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.b_seq_len.shape[0], + ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, + ) + self.nsa_cache_seqlens = torch.minimum( + torch.full(size=(self.infer_state.batch_size,), fill_value=2048, dtype=torch.int32, device="cuda"), + self.infer_state.b_seq_len, + ) + padded_seq_lens = torch.zeros(size=(self.nsa_cache_seqlens.shape[0] + 1,), dtype=torch.int32, device="cuda") + # 进行 cumsum 操作 + padded_seq_lens[1:].copy_(self.nsa_cache_seqlens, non_blocking=True) + self.nsa_cu_seqlens_k_new = padded_seq_lens.cumsum(dim=0, dtype=torch.int32) + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" + assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" + + return self._nsa_decode_att(q=q, 
kv=k, att_control=att_control) + + def _nsa_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + from sgl_kernel.flash_attn import flash_attn_with_kvcache + + nsa_dict = att_control.nsa_decode_dict + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + qk_rope_head_dim = nsa_dict["qk_rope_head_dim"] + + q_nope, q_rope = q + + # Extract k_rope and kv_nope from the KV buffer + k_rope = kv[:, :, -qk_rope_head_dim:].view(-1, 1, 1, qk_rope_head_dim) + kv_nope = kv[:, :, :-qk_rope_head_dim].view(-1, 1, 1, kv_lora_rank) + + o_tensor = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=topk_indices, + cache_seqlens=self.nsa_cache_seqlens, + cu_seqlens_q=self.infer_state.b1_cu_q_seq_len, + cu_seqlens_k_new=self.nsa_cu_seqlens_k_new, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=softmax_scale, + causal=True, + ) + return o_tensor diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 6bcf7fc03c..8f54e14a72 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -98,7 +98,7 @@ def _init_parallel_params(self): self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_ if self.enable_ep_moe: assert self.num_fused_shared_experts == 0, "num_fused_shared_experts must be 0 when enable_ep_moe" - logger.info( + logger.debug( f"global_rank {self.global_rank_} layerindex {self.layer_num_} " f"redundancy_expertids: {self.redundancy_expert_ids}" ) diff --git a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py new file mode 100644 index 
0000000000..4edad6177a --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py @@ -0,0 +1,166 @@ +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import Optional + + +@triton.jit +def _gen_nsa_ks_ke( + b_seq_len, + b_req_idx, + b_q_seq_len, + b_same_req_mark, + ks, + ke, + lengths, + req_to_token_index, + strided_req_to_token_index_b, + strided_req_to_token_index_s, + ragged_mem_index, + BLOCK_REQ: tl.constexpr, + BLOCK_SEQ_SPLIT: tl.constexpr, +): + cur_index = tl.program_id(0) + # 只处理最后一个同样req_idx的req进行处理,代表seq_len最长的那个。 + # req_mark 为 0,表示不是最后一个。 + req_mark = tl.load(b_same_req_mark + cur_index) + if req_mark == 0: + return + + off = tl.arange(0, BLOCK_REQ) + b_same_req_mark = tl.load(b_same_req_mark + off, off < cur_index, other=0) + pre_b_seq_len_data = tl.load(b_seq_len + off, (off < cur_index) & (b_same_req_mark != 0), other=0) + pre_sum_seq_len = tl.sum(pre_b_seq_len_data) + + # 兼容 prefill 和 decode 的情况, decode 可能存在 mtp 的情况,各个请求会共享一个req对象,其处理比较特殊 + q_seq_len = tl.load(b_q_seq_len + cur_index) + req_mark - 1 + cur_total_len = tl.load(b_seq_len + cur_index) + cur_req_idx = tl.load(b_req_idx + cur_index) + + b_q_seq_len_data = tl.load(b_q_seq_len + off, (off < (cur_index - req_mark + 1)), other=0) + store_start_index = tl.sum(b_q_seq_len_data) + + for block_index in range(tl.cdiv(q_seq_len, BLOCK_SEQ_SPLIT)): + block_start = block_index * BLOCK_SEQ_SPLIT + block_end = min(q_seq_len, (block_index + 1) * BLOCK_SEQ_SPLIT) + block_off = block_start + tl.arange(0, BLOCK_SEQ_SPLIT) + mask = block_off < block_end + ks_data = tl.zeros((BLOCK_SEQ_SPLIT,), dtype=tl.int32) + ke_data = (cur_total_len - q_seq_len) + tl.arange(0, BLOCK_SEQ_SPLIT) + + tl.store( + ks + store_start_index + block_off, + ks_data + pre_sum_seq_len, + mask=mask, + ) + tl.store( + ke + store_start_index + block_off, + ke_data + pre_sum_seq_len, + mask=mask, + ) + tl.store( + lengths + store_start_index + block_off, + 
ke_data - ks_data + 1, + mask=mask, + ) + + for block_index in range(tl.cdiv(cur_total_len, BLOCK_SEQ_SPLIT)): + block_start = block_index * BLOCK_SEQ_SPLIT + block_end = min(cur_total_len, (block_index + 1) * BLOCK_SEQ_SPLIT) + mask = block_start + tl.arange(0, BLOCK_SEQ_SPLIT) < block_end + + src_mem_index_ptr = ( + req_to_token_index + + strided_req_to_token_index_b * cur_req_idx + + block_start + + tl.arange(0, BLOCK_SEQ_SPLIT) + ) + src_mem_index = tl.load(src_mem_index_ptr, mask=mask, other=-1) + tl.store( + ragged_mem_index + pre_sum_seq_len + block_start + tl.arange(0, BLOCK_SEQ_SPLIT), src_mem_index, mask=mask + ) + return + + +@torch.no_grad() +def gen_nsa_ks_ke( + b_seq_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_req_idx: torch.Tensor, + req_to_token_index: torch.Tensor, + q_token_num: int, + ragged_mem_index: torch.Tensor, + hold_req_idx: int = -1, +): + """ + hold_req_idx 这是一个特殊req idx,主要是用于padding 请求数量时使用,所以其处理存在特殊性。 + """ + batch_size = b_seq_len.shape[0] + ks = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) + ke = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) + lengths = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) + b_same_req_mark = gen_same_req_mark(b_req_idx, hold_req_idx=hold_req_idx) + + _gen_nsa_ks_ke[(batch_size,)]( + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_q_seq_len=b_q_seq_len, + b_same_req_mark=b_same_req_mark, + ks=ks, + ke=ke, + lengths=lengths, + req_to_token_index=req_to_token_index, + strided_req_to_token_index_b=req_to_token_index.stride(0), + strided_req_to_token_index_s=req_to_token_index.stride(1), + ragged_mem_index=ragged_mem_index, + BLOCK_REQ=triton.next_power_of_2(batch_size), + BLOCK_SEQ_SPLIT=256, + ) + return ks, ke, lengths + + +@triton.jit +def _gen_same_req_mark(b_req_idx, b_same_req_mark, hold_req_idx, BLOCK_SIZE: tl.constexpr): + cur_index = tl.program_id(0) + cur_req_idx = tl.load(b_req_idx + cur_index) + # hold req idx 
可能重复,但是单独成组。 + if cur_req_idx == hold_req_idx: + tl.store(b_same_req_mark + cur_index, 1) + return + + off = tl.arange(0, BLOCK_SIZE) + pre_req_idxs = tl.load(b_req_idx + off, (off < cur_index) & (off < tl.num_programs(0)), other=-1) + after_req_idxs = tl.load(b_req_idx + off, (off > cur_index) & (off < tl.num_programs(0)), other=-1) + + pre_idx_count = tl.sum(pre_req_idxs == cur_req_idx) + after_idx_count = tl.sum(after_req_idxs == cur_req_idx) + + has_mark = tl.where(after_idx_count == 0, 1, 0) + for _ in range(has_mark): + tl.store(b_same_req_mark + cur_index, pre_idx_count + 1) + return + + +@torch.no_grad() +def gen_same_req_mark(b_req_idx: torch.Tensor, hold_req_idx: int = -1): + """ + b_req_idx: torch.Tensor + hold_req_idx: int default is -1, hold_req_idx is used to pad batch size to cuda graph batch size, + so need special handle. + out: torch.Tensor + + demo: + b_req_idx = [1, 1, 2, 3, 3, 3] + out = [0, 2, 1, 0, 0, 3] + """ + batch_size = b_req_idx.shape[0] + b_same_req_mark = torch.empty((batch_size,), dtype=torch.int32, device=b_req_idx.device) + b_same_req_mark.fill_(0) + _gen_same_req_mark[(batch_size,)]( + b_req_idx=b_req_idx, + b_same_req_mark=b_same_req_mark, + hold_req_idx=hold_req_idx, + BLOCK_SIZE=triton.next_power_of_2(batch_size), + ) + return b_same_req_mark diff --git a/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py index 5402a9caf1..0a68372887 100644 --- a/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py +++ b/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py @@ -154,7 +154,7 @@ def per_token_group_quant_fp8( ) # 使用SGL kernel进行量化 - sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max, False) + sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max, False, enable_v2=True) else: # 使用LightLLM kernel进行量化 x_s = alloc_func( 
diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 7d516e6728..d41d1555a3 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -4,6 +4,7 @@ from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager +from .deepseek3_2mem_manager import Deepseek3_2MemoryManager __all__ = [ "MemoryManager", @@ -13,4 +14,5 @@ "PPLINT4KVMemoryManager", "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", + "Deepseek3_2MemoryManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py new file mode 100644 index 0000000000..fbf9f88c84 --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py @@ -0,0 +1,36 @@ +import torch +from typing import Any +from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager + + +class Deepseek3_2MemoryManager(Deepseek2MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + assert dtype in [torch.bfloat16, torch.float16] + # 因为V3.2 使用了NSA 稀疏的缘故,所以其head_dim 会比原始的kv 多 128 + 4 = 132 个字节 (128 fp8 + 4byte float32 scale), + # 但是为了让整个数组具备16字节对齐,满足一些算子的约束,修改为添加 128 + 16 = 144 个字节, 这 144个字节中,后面132个字节用于 + # 存储真实数据,剩下12个,浪费了,只是占位。 + # 所以在子类中定制为其pad上,对外使用的接口,需要进行重载区别。 + super().__init__(size, dtype, head_num, head_dim + (144 // 2), layer_num, always_copy, mem_fraction) + + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from ..basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv + + rope_dim = 64 + kv_lora_rank = kv.shape[2] - rope_dim + assert kv_lora_rank + rope_dim == self.kv_buffer.shape[-1] 
- (144 // 2) + + destindex_copy_kv( + kv[:, :, :kv_lora_rank], + kv[:, :, kv_lora_rank:], + mem_index, + self.kv_buffer[layer_index][:, :, :kv_lora_rank], + self.kv_buffer[layer_index][:, :, kv_lora_rank : (kv_lora_rank + rope_dim)], + ) + return + + def get_att_input_params(self, layer_index: int) -> Any: + kv = self.kv_buffer[layer_index][:, :, : (self.head_dim - (144 // 2))] + return kv diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 1ff58b89a0..b22590a6f2 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -5,6 +5,7 @@ PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, + Deepseek3_2MemoryManager, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args @@ -19,6 +20,14 @@ def select_mem_manager_class(): # case 1 # 先判断是否是 deepseek 系列的模型 model_class = get_llm_model_class() + + from lightllm.models import Deepseek3_2TpPartModel + + if issubclass(model_class, Deepseek3_2TpPartModel): + mem_class = Deepseek3_2MemoryManager + logger.info(f"Model kv cache using default, mem_manager class: {mem_class}") + return mem_class + from lightllm.models import Deepseek2TpPartModel if issubclass(model_class, Deepseek2TpPartModel): diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 32ccbe8337..f2e29d4a88 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -18,6 +18,7 @@ from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel from lightllm.models.phi3.model import Phi3TpPartModel from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, diff --git a/lightllm/models/deepseek3_2/__init__.py 
b/lightllm/models/deepseek3_2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek3_2/encoding_dsv32.py b/lightllm/models/deepseek3_2/encoding_dsv32.py new file mode 100644 index 0000000000..3ac4b83714 --- /dev/null +++ b/lightllm/models/deepseek3_2/encoding_dsv32.py @@ -0,0 +1,429 @@ +# Adapted from vLLM's deepseek_v32_encoding.py +# (https://github.com/vllm-project/vllm), which was originally adapted from +# https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/encoding/encoding_dsv32.py +# +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +import json +import re +from typing import Any + +# flake8: noqa: E501 +TOOLS_SYSTEM_TEMPLATE = """## Tools +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user: +<{dsml_token}function_calls> +<{dsml_token}invoke name="$FUNCTION_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$FUNCTION_NAME2"> +... + + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). +If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example: +<{dsml_token}function_calls> +... + + +... 
+ +{thinking_start_token}...thinking about results{thinking_end_token} +Here are the functions available in JSONSchema format: + +{tool_schemas} + +""" + +bos_token: str = "<|begin▁of▁sentence|>" +eos_token: str = "<|end▁of▁sentence|>" +thinking_start_token: str = "" +thinking_end_token: str = "" +dsml_token: str = "|DSML|" +system_msg_template: str = "{content}" +user_msg_template: str = "<|User|>{content}<|Assistant|>" +assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>" +thinking_template = "{reasoning}" + +response_format_template: str = ( + "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}" +) +tool_call_template: str = '<{dsml_token}invoke name="{name}">\n{arguments}\n' +tool_calls_template = "<{dsml_token}function_calls>\n{tool_calls}\n" + +tool_output_template: str = "\n{content}" + + +def to_json(value: Any) -> str: + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return json.dumps(value, ensure_ascii=True) + + +def tools_from_openai_format(tools): + return [tool["function"] for tool in tools] + + +def tool_calls_from_openai_format(tool_calls): + return [ + { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + for tool_call in tool_calls + ] + + +def tool_calls_to_openai_format(tool_calls): + return [ + { + "type": "function", + "function": { + "name": tool_call["name"], + "arguments": tool_call["arguments"], + }, + } + for tool_call in tool_calls + ] + + +def encode_arguments_to_dsml(tool_call: dict) -> str: + p_dsml_template = """<{dsml_token}parameter name="{key}" string="{is_str}">{value}""" + P_dsml_strs = [] + if isinstance(tool_call["arguments"], str): + arguments = json.loads(tool_call["arguments"]) + else: + arguments = tool_call["arguments"] + + for k, v in arguments.items(): + p_dsml_str = p_dsml_template.format( + dsml_token=dsml_token, + key=k, + is_str="true" if isinstance(v, str) else "false", + 
value=v if isinstance(v, str) else to_json(v), + ) + + P_dsml_strs.append(p_dsml_str) + + return "\n".join(P_dsml_strs) + + +def decode_dsml_to_arguments(tool_name, tool_args): + def _decode_value(key, value, string): + if string == "true": + value = to_json(value) + return f"{to_json(key)}: {value}" + + tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}" + return dict(name=tool_name, arguments=tool_args_json) + + +def render_tools(tools): + tools_json = [to_json(t) for t in tools] + + return TOOLS_SYSTEM_TEMPLATE.format( + tool_schemas="\n".join(tools_json), + dsml_token=dsml_token, + thinking_start_token=thinking_start_token, + thinking_end_token=thinking_end_token, + ) + + +def find_last_user_index(messages): + last_user_index = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx].get("role") in ["user", "developer"]: + last_user_index = idx + break + return last_user_index + + +def render_message(index, messages, thinking_mode): + if not (0 <= index < len(messages)): + raise ValueError(f"Index {index} out of range for messages list of length {len(messages)}") + if thinking_mode not in ["chat", "thinking"]: + raise ValueError(f"Invalid thinking_mode `{thinking_mode}`") + + prompt = "" + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = msg.get("role") + content = msg.get("content") + tools = msg.get("tools") + response_format = msg.get("response_format") + tool_calls = msg.get("tool_calls") + reasoning = msg.get("reasoning") + is_prefix = msg.get("prefix", False) + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + if role == "system": + prompt += system_msg_template.format(content=content or "") + if tools: + prompt += "\n\n" + render_tools(tools) + + if response_format: + prompt += "\n\n" + response_format_template.format(schema=to_json(response_format)) + + elif role 
== "developer": + if not content: + raise ValueError(f"Invalid message for role `{role}`: {msg}") + content_developer = "" + if tools: + content_developer += "\n\n" + render_tools(tools) + + if response_format: + content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format)) + + content_developer += "\n\n# The user's message is: {}".format(content) + + prompt += user_msg_template.format(content=content_developer) + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "user": + prompt += user_msg_template.format(content=content) + + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "tool": + prev_assistant_idx = index - 1 + assistant_msg = messages[prev_assistant_idx] + while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool": + prev_assistant_idx -= 1 + assistant_msg = messages[prev_assistant_idx] + + if not (index == 0 or prev_assistant_idx >= 0 and assistant_msg.get("role") == "assistant"): + raise ValueError(f"Invalid messages at {index}:\n{assistant_msg}") + + tool_call_order = index - prev_assistant_idx + assistant_tool_calls = assistant_msg.get("tool_calls") + if not (assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order): + raise ValueError("No tool calls but found tool output") + + if tool_call_order == 1: + prompt += "\n\n" + + prompt += tool_output_template.format(content=content) + + if tool_call_order == len(assistant_tool_calls): + prompt += "\n" + + if index >= last_user_idx and thinking_mode == "thinking": + prompt += "\n\n" + thinking_start_token + else: + prompt += "\n\n" + thinking_end_token + + elif role == "assistant": + thinking_part = "" + + tool_calls_content = "" + if tool_calls: + tool_calls = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tool_call.get("name"), + 
arguments=encode_arguments_to_dsml(tool_call), + ) + for tool_call in tool_calls + ] + tool_calls_content += "\n\n" + tool_calls_template.format( + dsml_token=dsml_token, tool_calls="\n".join(tool_calls) + ) + + summary_content = content or "" + + if thinking_mode == "thinking" and index > last_user_idx: + if not (reasoning or tool_calls): + raise ValueError( + f"ThinkingMode: {thinking_mode}, invalid message without reasoning/tool_calls `{msg}` after last user message" + ) + thinking_part = thinking_template.format(reasoning=reasoning or "") + thinking_end_token + + if not tool_calls and is_prefix: + prompt += summary_content + else: + prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tool_calls_content, + ) + else: + raise NotImplementedError(f"Unknown role: {role}") + + return prompt + + +def drop_thinking_messages(messages, last_user_idx=None): + messages_wo_thinking = [] + last_user_idx = find_last_user_index(messages) if last_user_idx is None else last_user_idx + for idx, msg in enumerate(messages): + role = msg.get("role") + if role in ["user", "system", "tool"] or idx >= last_user_idx: + messages_wo_thinking.append(msg) + continue + + elif role == "assistant": + msg_wo_thinking = copy.copy(msg) + msg_wo_thinking.pop("reasoning", None) + messages_wo_thinking.append(msg_wo_thinking) + + return messages_wo_thinking + + +def encode_messages( + messages, + thinking_mode, + context=None, + drop_thinking=True, + add_default_bos_token=True, +): + context = context if context else [] + full_messages = context + messages + + prompt = bos_token if add_default_bos_token and len(context) == 0 else "" + + if thinking_mode == "thinking" and drop_thinking: + full_messages = drop_thinking_messages(full_messages) + + for idx in range(len(messages)): + prompt += render_message(idx + len(context), full_messages, thinking_mode=thinking_mode) + + return prompt + + +def _read_until_stop(index, text, stop): + min_pos = 
len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls(index, text): + tool_calls = [] + stop_token = None + tool_calls_end_token = f"" + + while index < len(text): + index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token]) + if _ != ">\n": + raise RuntimeError("Tool call format error") + + if stop_token == tool_calls_end_token: + break + + if stop_token is None: + raise RuntimeError("Missing special token") + + index, tool_name_content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n$', tool_name_content, flags=re.DOTALL) + if len(p_tool_name) != 1: + raise RuntimeError("Tool name format error") + tool_name = p_tool_name[0] + + tool_args = {} + while stop_token == f"<{dsml_token}parameter": + index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"]) + + param_kv = re.findall( + r'^ name="(.*?)" string="(true|false)">(.*?)<$', + param_content, + flags=re.DOTALL, + ) + if len(param_kv) != 1: + raise RuntimeError("Parameter format error") + param_name, string, param_value = param_kv[0] + + if param_name in tool_args: + raise RuntimeError("Duplicate parameter name") + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n": + raise RuntimeError("Parameter format error") + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +# NOTE: This function is designed to parse only correctly +# formatted string and will not attempt to correct malformed output +# that may be generated 
by the model. +def parse_message_from_completion_text(text, thinking_mode): + summary_content, reasoning, tool_calls = "", "", [] + index, stop_token = 0, None + tool_calls_start_token = f"\n\n<{dsml_token}function_calls" + + is_thinking, is_tool_calling = thinking_mode == "thinking", False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token]) + reasoning = content_delta + if stop_token != thinking_end_token: + raise RuntimeError("Invalid thinking format") + + index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token]) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + if stop_token != eos_token: + raise RuntimeError("Invalid summary format") + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + if tool_ends_text: + raise RuntimeError("Unexpected content after tool calls") + + if not (len(text) == index and stop_token in [eos_token, None]): + raise RuntimeError("Unexpected content at end") + + for sp_token in [ + bos_token, + eos_token, + thinking_start_token, + thinking_end_token, + dsml_token, + ]: + if sp_token in summary_content or sp_token in reasoning: + raise RuntimeError("Unexpected special token in content") + + return { + "role": "assistant", + "content": summary_content, + "reasoning": reasoning, + "tool_calls": tool_calls_to_openai_format(tool_calls), + } diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..49d8acb8d6 --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -0,0 +1,259 @@ +import torch +from typing import Union +from lightllm.models.deepseek2.infer_struct import 
Deepseek2InferStateInfo
from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight
from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward
from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.common.basemodel.attention.base_att import AttControl
from lightllm.common.basemodel.attention.nsa import NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState
from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant
from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks
from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks
from lightllm.utils.envs_utils import get_env_start_args


class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer):
    """DeepSeek-V3.2 layer inference: DeepSeek-V2-style MLA attention plus an
    NSA (native sparse attention) indexer that restricts each query to its
    top-k KV positions."""

    def __init__(self, layer_num, network_config):
        # Number of KV positions each query token may attend to.
        self.index_topk = network_config["index_topk"]
        super().__init__(layer_num, network_config)

        self.indexer = NsaInfer(layer_idx=self.layer_num_, network_config=self.network_config_)
        return

    def _get_qkv(
        self,
        input: torch.Tensor,
        infer_state: Deepseek2InferStateInfo,
        layer_weight: Deepseek3_2TransformerLayerWeight,
    ) -> torch.Tensor:
        """Compute q and the compressed kv cache entry for this layer.

        Also stashes (hidden_states, q_lora) on infer_state so the attention
        kernels below can run the indexer; the stash is deleted there after use.
        """
        input = input.view(-1, self.embed_dim_)

        q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split(
            [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
        )
        q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_)

        # Saved for the indexer (consumed in _context/_token_attention_kernel).
        infer_state.get_topk_indices_params = {
            "hidden_states": input,
            "q_lora": q,
        }

        q = layer_weight.q_b_proj_.mm(q)
        cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim)
        q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim)
        q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Normalize the latent part of the kv cache in place.
        rmsnorm_forward(
            cache_kv[:, :, : self.kv_lora_rank],
            weight=layer_weight.kv_a_layernorm_.weight,
            eps=self.eps_,
            out=cache_kv[:, :, : self.kv_lora_rank],
        )

        rotary_emb_fwd(
            q_rope,
            cache_kv[:, :, self.kv_lora_rank :],
            infer_state.position_cos,
            infer_state.position_sin,
        )
        return q, cache_kv

    def _context_attention_kernel(
        self,
        q: torch.Tensor,
        kv,
        infer_state: Deepseek2InferStateInfo,
        layer_weight: Deepseek3_2TransformerLayerWeight,
        out=None,
    ) -> torch.Tensor:
        """Prefill attention: project q into the latent space, run the NSA
        indexer for top-k indices, then call the sparse prefill backend."""
        # Model-specific q projection (uses layer weights)
        q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
        q_all = torch.cat([q_nope, q_rope], dim=-1)

        # Compute the sparse top-k indices via the indexer.
        att_state = infer_state.prefill_att_state
        topk_indices = self.indexer.get_indices(
            hidden_states=infer_state.get_topk_indices_params["hidden_states"],
            q_lora=infer_state.get_topk_indices_params["q_lora"],
            infer_state=infer_state,
            att_state=att_state,
            layer_weight=layer_weight,
        )
        del infer_state.get_topk_indices_params

        # Use NSA backend for attention computation
        att_control = AttControl(
            nsa_prefill=True,
            nsa_prefill_dict={
                "topk_indices": topk_indices,
                "softmax_scale": self.softmax_scale,
                "kv_lora_rank": self.kv_lora_rank,
            },
        )

        mla_out = infer_state.prefill_att_state.prefill_att(
            q=q_all,
            k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_),
            v=None,
            att_control=att_control,
        )
        return mla_out

    def _token_attention_kernel(
        self,
        q,
        infer_state: Deepseek2InferStateInfo,
        layer_weight: Deepseek3_2TransformerLayerWeight,
        out=None,
    ):
        """Decode attention: same pattern as prefill, but q is passed to the
        backend split into (nope, rope) parts."""
        # Model-specific q projection (uses layer weights)
        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

        # Compute the sparse top-k indices via the indexer.
        att_state = infer_state.decode_att_state
        topk_indices = self.indexer.get_indices(
            hidden_states=infer_state.get_topk_indices_params["hidden_states"],
            q_lora=infer_state.get_topk_indices_params["q_lora"],
            infer_state=infer_state,
            att_state=att_state,
            layer_weight=layer_weight,
        )
        del infer_state.get_topk_indices_params

        # Use NSA backend for attention computation
        att_control = AttControl(
            nsa_decode=True,
            nsa_decode_dict={
                "topk_indices": topk_indices,
                "softmax_scale": self.softmax_scale,
                "kv_lora_rank": self.kv_lora_rank,
                "qk_rope_head_dim": self.qk_rope_head_dim,
            },
        )

        o_tensor = infer_state.decode_att_state.decode_att(
            q=(q_nope, q_rope),
            k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_),
            v=None,
            att_control=att_control,
        )
        return o_tensor


class NsaInfer:
    """NSA indexer: scores every cached KV position against each query with a
    lightweight fp8 MQA pass and returns the top-k memory indices."""

    def __init__(self, layer_idx: int, network_config: dict):
        super().__init__()
        self.layer_idx_ = layer_idx
        self.network_config_ = network_config
        self.index_topk = network_config["index_topk"]
        self.qk_nope_head_dim = network_config["qk_nope_head_dim"]
        self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
        self.index_head_dim = network_config["index_head_dim"]
        self.eps = network_config["rms_norm_eps"]
        self.block_size = network_config["quantization_config"]["weight_block_size"][0]
        self.scale_fmt = network_config["quantization_config"]["scale_fmt"]
        self.softmax_scale = (self.index_head_dim) ** (-0.5)
        self.index_n_heads = network_config["index_n_heads"]
        # Head-count normalization folded into the per-head weights.
        self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale

    def get_indices(
        self,
        hidden_states: torch.Tensor,
        q_lora: torch.Tensor,
        infer_state: Deepseek2InferStateInfo,
        att_state: Union[NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState],
        layer_weight: Deepseek3_2TransformerLayerWeight,
    ) -> torch.Tensor:
        """Return per-token top-k memory-pool indices for sparse attention."""

        q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight)
        q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt)
        k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt)

        # Persist the indexer keys into the tail of the kv buffer rows.
        # NOTE(review): the last 132 bytes appear to hold 128 fp8 key bytes +
        # 4 bytes of fp32 scale (see destindex_copy_indexer_ks) — confirm
        # against the mem_manager buffer layout.
        destindex_copy_indexer_ks(
            K_fp8=k_fp8,
            K_scale=k_scale,
            DestLoc=infer_state.mem_index,
            O_buffer=infer_state.mem_manager.kv_buffer[self.layer_idx_].view(dtype=torch.uint8)[:, :, -132:],
        )

        weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale
        weights = weights.unsqueeze(-1) * q_scale

        ks = att_state.ks
        ke = att_state.ke
        lengths = att_state.lengths

        if infer_state.is_prefill:
            mtp_step = 0
        else:
            mtp_step = get_env_start_args().mtp_step
        # Use efficient Triton kernel to extract FP8 keys and scales from buffer
        k_fp8_, k_scale_ = extract_indexer_ks(
            I_buffer=infer_state.mem_manager.kv_buffer[self.layer_idx_].view(dtype=torch.uint8)[:, :, -132:],
            b_seq_len=infer_state.b_seq_len,
            b_req_idx=infer_state.b_req_idx,
            req_to_token_indexs=infer_state.req_manager.req_to_token_indexs,
            out_token_num=infer_state.b_seq_len.shape[0] * infer_state.max_kv_seq_len,
            max_kv_seq_len=infer_state.max_kv_seq_len,
            mtp_step=mtp_step,
        )

        import deep_gemm

        logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke)

        from sgl_kernel import fast_topk_v2

        b_topk_index = fast_topk_v2(
            score=logits,
            lengths=lengths,
            topk=self.index_topk,
            row_starts=ks,
        )
        # Translate top-k positions into memory-pool indices.

        from ..triton_kernel.topk_index_to_mem_index import trans_topk_index_to_mem_index

        b_topk_index = trans_topk_index_to_mem_index(
            topk_index=b_topk_index,
            ragged_mem_index=att_state.ragged_mem_index,
        )

        return b_topk_index

    @staticmethod
    def _rotate_activation(x: torch.Tensor) -> torch.Tensor:
        # Hadamard rotation spreads activation energy evenly across the head
        # dim before fp8 quantization.
        assert x.dtype == torch.bfloat16
        from sgl_kernel import hadamard_transform

        hidden_size = x.size(-1)
        assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform."
        return hadamard_transform(x, scale=hidden_size ** -0.5)

    def _get_q_k_bf16(
        self,
        hidden_states: torch.Tensor,
        q_lora: torch.Tensor,
        infer_state: Deepseek2InferStateInfo,
        layer_weight: Deepseek3_2TransformerLayerWeight,
    ):
        """Compute the indexer's bf16 q/k (projection + norm + rope + Hadamard)."""
        q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim)
        k = layer_weight.wk_proj_.mm(hidden_states)

        k = layer_weight.k_norm_(k, eps=self.eps)

        # NOTE: the indexer uses a different rotary-embedding layout than the
        # main model's q/k, hence the llama-style rotary kernel import here.
        from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd

        rotary_emb_fwd(
            q[:, :, : self.qk_rope_head_dim],
            k[:, None, : self.qk_rope_head_dim],
            infer_state.position_cos,
            infer_state.position_sin,
        )

        q = self._rotate_activation(q)
        k = self._rotate_activation(k)
        return q, k
diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py
new file mode 100644
index 0000000000..eb14c82b49
--- /dev/null
+++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py
from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, LayerNormWeight


class Deepseek3_2TransformerLayerWeight(Deepseek2TransformerLayerWeight):
    """DeepSeek-V3.2 layer weights: DeepSeek-V2 weights plus the NSA indexer
    projections (wq_b, wk, k_norm, weights_proj)."""

    def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
        super().__init__(layer_num, data_type, network_config, quant_cfg)
        return

    def _parse_config(self):
        super()._parse_config()
        self.q_lora_rank = self.network_config_["q_lora_rank"]
        self.index_n_heads = self.network_config_["index_n_heads"]
        self.index_head_dim = self.network_config_["index_head_dim"]
        self.hidden_size = self.network_config_["hidden_size"]

    def _init_weight(self):
        super()._init_weight()
        self._init_indexer_weight()

    def _init_indexer_weight(self):

        prefix =
f"model.layers.{self.layer_num_}.self_attn.indexer" + + self.wq_b_proj_ = ROWMMWeight( + in_dim=self.q_lora_rank, + out_dims=[self.index_n_heads * self.index_head_dim], + weight_names=f"{prefix}.wq_b.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.wk_proj_ = ROWMMWeight( + in_dim=self.hidden_size, + out_dims=[self.index_head_dim], + weight_names=f"{prefix}.wk.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.k_norm_ = LayerNormWeight( + dim=self.index_head_dim, + weight_name=f"{prefix}.k_norm.weight", + data_type=self.data_type_, + bias_name=f"{prefix}.k_norm.bias", + ) + self.weights_proj_ = ROWMMWeight( + in_dim=self.hidden_size, + out_dims=[self.index_n_heads], + weight_names=f"{prefix}.weights_proj.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py new file mode 100644 index 0000000000..5831044311 --- /dev/null +++ b/lightllm/models/deepseek3_2/model.py @@ -0,0 +1,110 @@ +import copy +from lightllm.models.registry import ModelRegistry +from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer +from lightllm.common.basemodel.attention import get_nsa_prefill_att_backend_class, get_nsa_decode_att_backend_class + + +@ModelRegistry(["deepseek_v32"]) +class Deepseek3_2TpPartModel(Deepseek2TpPartModel): + + # weight class + transformer_weight_class = Deepseek3_2TransformerLayerWeight + + # infer class + transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer + + def _init_att_backend(self): + self.prefill_att_backend = get_nsa_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend 
= get_nsa_decode_att_backend_class(index=0)(model=self) + return + + +class DeepSeekV32Tokenizer: + """Tokenizer wrapper for DeepSeek-V3.2 that uses the Python-based + encoding_dsv32 module instead of Jinja chat templates. + + DeepSeek-V3.2's tokenizer_config.json does not ship with a Jinja chat + template, so ``apply_chat_template`` would fail without either a manually + supplied ``--chat_template`` file or this wrapper. + """ + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + # Cache added vocabulary for performance (HuggingFace can be slow). + self._added_vocab = None + + # ------------------------------------------------------------------ + # Attribute delegation – everything not overridden goes to the inner + # tokenizer so that encode/decode/vocab_size/eos_token_id/… all work. + # ------------------------------------------------------------------ + def __getattr__(self, name): + return getattr(self.tokenizer, name) + + def get_added_vocab(self): + if self._added_vocab is None: + self._added_vocab = self.tokenizer.get_added_vocab() + return self._added_vocab + + # ------------------------------------------------------------------ + # Core override: route apply_chat_template through encode_messages. + # ------------------------------------------------------------------ + def apply_chat_template( + self, + conversation=None, + messages=None, + tools=None, + tokenize=False, + add_generation_prompt=True, + thinking=None, + **kwargs, + ): + from lightllm.models.deepseek3_2.encoding_dsv32 import encode_messages, render_tools + + msgs = conversation if conversation is not None else messages + if msgs is None: + raise ValueError("Either 'conversation' or 'messages' must be provided") + + # Deep copy to avoid mutating the caller's messages. + msgs = copy.deepcopy(msgs) + + # Determine thinking mode. 
+ thinking_mode = "thinking" if thinking else "chat" + + # Inject tools into the first system message (or create one) so that + # encode_messages / render_message picks them up. + if tools: + # build_prompt passes tools as bare function dicts: + # [{"name": "f", "description": "...", "parameters": {...}}] + # encoding_dsv32's render_message expects OpenAI wrapper format: + # [{"type": "function", "function": {...}}] + wrapped_tools = [] + for t in tools: + if "function" in t: + wrapped_tools.append(t) + else: + wrapped_tools.append({"type": "function", "function": t}) + + injected = False + for msg in msgs: + if msg.get("role") == "system": + existing = msg.get("tools") or [] + msg["tools"] = existing + wrapped_tools + injected = True + break + + if not injected: + # Prepend a system message that carries the tools. + msgs.insert(0, {"role": "system", "content": "", "tools": wrapped_tools}) + + prompt = encode_messages( + msgs, + thinking_mode=thinking_mode, + drop_thinking=kwargs.get("drop_thinking", True), + add_default_bos_token=kwargs.get("add_default_bos_token", True), + ) + + if tokenize: + return self.tokenizer.encode(prompt, add_special_tokens=False) + return prompt diff --git a/lightllm/models/deepseek3_2/triton_kernel/__init__.py b/lightllm/models/deepseek3_2/triton_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek3_2/triton_kernel/act_quant.py b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py new file mode 100644 index 0000000000..a4ecd0f518 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py @@ -0,0 +1,137 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/ce6b17c0f94e6bf53633c8f324176a891e67fa7f/python/sglang/srt/layers/attention/nsa/triton_kernel.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +# Triton implementation +@triton.jit +def _act_quant_kernel( + X_ptr, + Y_ptr, + S_ptr, + M, + N, + 
    group_size: tl.constexpr,
    round_scale: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Triton kernel for activation quantization.

    Each block processes BLOCK_M rows and group_size columns.
    """
    # Get block IDs
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # FP8 constants (e4m3 finite range)
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1.0 / fp8_max

    # Calculate row and column offsets
    row_start = pid_m * BLOCK_M
    col_start = pid_n * group_size

    # Create offset arrays
    rows = row_start + tl.arange(0, BLOCK_M)
    cols = col_start + tl.arange(0, BLOCK_N)

    # Mask for valid rows and columns
    row_mask = rows < M
    col_mask = cols < N
    mask = row_mask[:, None] & col_mask[None, :]

    # Load input data
    x_ptrs = X_ptr + rows[:, None] * N + cols[None, :]
    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)

    # Compute absolute max along columns (group_size dimension) for each row
    x_abs = tl.abs(x)
    amax = tl.max(x_abs, axis=1)  # Shape: (BLOCK_M,)

    # Clamp amax to avoid division by zero
    amax = tl.maximum(amax, 1e-4)

    # Compute scale
    if round_scale:
        # Fast round scale using bit manipulation approximation
        # This is a simplified version - the exact bit manipulation is harder in Triton
        # Using log2 + ceil + pow2 as approximation (rounds the scale up to a
        # power of two so the scale itself is exactly representable)
        log_val = tl.log2(amax * fp8_max_inv)
        log_ceil = tl.ceil(log_val)
        scale = tl.exp2(log_ceil)
    else:
        scale = amax * fp8_max_inv

    # Quantize: y = clamp(x / scale, fp8_min, fp8_max)
    scale_broadcast = scale[:, None]
    y = x / scale_broadcast
    y = tl.minimum(tl.maximum(y, fp8_min), fp8_max)

    # Store quantized output
    y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :]
    tl.store(y_ptrs, y, mask=mask)

    # Store scales (one fp32 scale per row per group)
    s_cols = pid_n
    s_ptrs = S_ptr + rows * (N // group_size) + s_cols
    s_mask = row_mask
    tl.store(s_ptrs, scale, mask=s_mask)


def act_quant(
    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization with Triton.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
        scale_fmt (Optional[str], optional): The format of the scale. Default is None.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert (
        x.size(-1) % block_size == 0
    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"

    # Flatten all dims except last
    N = x.size(-1)
    x_flat = x.view(-1, N)
    M = x_flat.size(0)

    # Allocate output tensors
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    y_flat = y.view(-1, N)
    s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
    s_flat = s.view(-1, N // block_size)

    # Launch kernel
    BLOCK_M = 32
    BLOCK_N = block_size
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size))
    # round_scale is on whenever an explicit scale format is requested.
    round_scale = scale_fmt is not None

    _act_quant_kernel[grid](
        x_flat,
        y_flat,
        s_flat,
        M,
        N,
        group_size=block_size,
        round_scale=round_scale,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_stages=0 if round_scale else 2,
    )

    return y, s
diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py
new file mode 100644
index 0000000000..a3115cc61c
--- /dev/null
+++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_destindex_copy_indexer_ks(
    K_fp8,
    K_scale,
    DestLoc,
    stride_k_bs,
    stride_k_d,
    stride_scale_bs,
    stride_scale_d,
    O_fp8,
    stride_o_bs,
    stride_o_d,
    O_fp8_scale,
    stride_o_scale_bs,
    stride_o_scale_d,
    BLOCK_DMODEL: tl.constexpr,
):
    # Scatter one token's indexer key (128 fp8 values) and its fp32 scale
    # into the destination slots given by DestLoc; one program per token.
    cur_index = tl.program_id(0)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    # Load destination index for this thread
    dest_index = tl.load(DestLoc + cur_index).to(tl.int64)

    # Load K_fp8 (128 values) and K_scale (1 value) from source
    k_fp8_ptrs = K_fp8 + cur_index * stride_k_bs + stride_k_d * offs_d
    k_fp8 = tl.load(k_fp8_ptrs)

    k_scale = tl.load(K_scale + cur_index * stride_scale_bs + stride_scale_d * 0)

    o_k_ptrs = O_fp8 + dest_index * stride_o_bs + stride_o_d * offs_d
    tl.store(o_k_ptrs, k_fp8)

    o_scale_ptr = O_fp8_scale + dest_index * stride_o_scale_bs + stride_o_scale_d * 0
    tl.store(o_scale_ptr, k_scale)
    return


@torch.no_grad()
def destindex_copy_indexer_ks(
    K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor
):
    """Scatter per-token indexer keys/scales into the kv-buffer tail.

    O_buffer's last dim is 132 bytes per row: bytes 0..127 hold the fp8 key
    and bytes 128..131 the fp32 scale.  DestLoc gives the destination row
    (memory-pool index) for each of the seq_len source tokens.
    """
    seq_len = DestLoc.shape[0]
    head_dim = K_fp8.shape[1]

    assert head_dim == 128, f"Expected head_dim=128, got {head_dim}"
    assert O_buffer.shape[2] == 132, f"Expected O_buffer last dim=132, got {O_buffer.shape[2]}"
    assert K_fp8.shape[0] == seq_len, f"Expected K_fp8 shape[0]={seq_len}, got {K_fp8.shape[0]}"
    K_fp8 = K_fp8.view(-1, head_dim)
    K_scale = K_scale.view(-1, 1)

    assert K_fp8.shape[0] == seq_len, f"Expected K_fp8 shape[0]={seq_len}, got {K_fp8.shape[0]}"
    assert K_scale.shape[0] == seq_len, f"Expected K_scale shape[0]={seq_len}, got {K_scale.shape[0]}"

    # Reinterpret the byte buffer: first 128 bytes fp8 key, last 4 bytes fp32 scale.
    O_fp8 = O_buffer[:, :, :128].view(dtype=torch.float8_e4m3fn).view(-1, head_dim)
    O_fp8_scale = O_buffer[:, :, 128:132].view(dtype=torch.float32).view(-1, 1)

    grid = (seq_len,)
    num_warps = 1

    _fwd_kernel_destindex_copy_indexer_ks[grid](
        K_fp8=K_fp8,
        K_scale=K_scale,
        DestLoc=DestLoc,
        stride_k_bs=K_fp8.stride(0),
        stride_k_d=K_fp8.stride(1),
        stride_scale_bs=K_scale.stride(0),
        stride_scale_d=K_scale.stride(1),
        O_fp8=O_fp8,
        stride_o_bs=O_fp8.stride(0),
        stride_o_d=O_fp8.stride(1),
        O_fp8_scale=O_fp8_scale,
        stride_o_scale_bs=O_fp8_scale.stride(0),
        stride_o_scale_d=O_fp8_scale.stride(1),
        BLOCK_DMODEL=head_dim,
        num_warps=num_warps,
        num_stages=1,
    )
    return
diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py
new file mode 100644
index 0000000000..d0f8b45f81
--- /dev/null
+++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_extract_indexer_ks(
    in_fp8,
    stride_in_fp8_bs,
    stride_in_fp8_h,
    stride_in_fp8_d,
    in_fp8_scale,
    stride_in_scale_bs,
    stride_in_scale_h,
    stride_in_scale_d,
    req_to_token_indexs,
    stride_req_to_token_m,
    stride_req_to_token_n,
    b_seq_len,
    b_req_idx,
    O_fp8,
    stride_o_fp8_bs,
    stride_o_fp8_d,
    O_scale,
    stride_o_scale_bs,
    stride_o_scale_d,
    mtp_step,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_SEQ_LEN: tl.constexpr,
):
    # Gather each request's cached indexer keys/scales into a dense,
    # request-contiguous output.  One program per (request, token-stride):
    # axis 0 = request (after collapsing mtp_step draft slots, keeping the
    # last slot of each group), axis 1 = starting token offset.
    origin_cur_req_index = tl.program_id(0)
    cur_req_index = (origin_cur_req_index + 1) * (mtp_step + 1) - 1
    token_start_index = tl.program_id(1)
    cur_req_idx = tl.load(b_req_idx + cur_req_index)
    cur_seq_len = tl.load(b_seq_len + cur_req_index)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    # Exclusive prefix sum of the preceding requests' seq lens gives this
    # request's start row in the dense output.
    b_seq_len = tl.load(
        b_seq_len + (tl.arange(0, BLOCK_SEQ_LEN) + 1) * (mtp_step + 1) - 1,
        mask=tl.arange(0, BLOCK_SEQ_LEN) < origin_cur_req_index,
        other=0,
    )
    store_start_index = tl.sum(b_seq_len)

    # Grid-strided loop over this request's tokens.
    for i in range(token_start_index, cur_seq_len, tl.num_programs(1)):
        mem_index = tl.load(req_to_token_indexs + cur_req_idx * stride_req_to_token_m + i * stride_req_to_token_n)

        in_fp8_ptrs = in_fp8 + mem_index * stride_in_fp8_bs + 0 * stride_in_fp8_h + stride_in_fp8_d * offs_d
        kv_fp8 = tl.load(in_fp8_ptrs)

        in_scale_ptrs = in_fp8_scale + mem_index * stride_in_scale_bs + 0 * stride_in_scale_h + 0 *
stride_in_scale_d
        kv_scale = tl.load(in_scale_ptrs)

        o_fp8_ptrs = O_fp8 + (store_start_index + i) * stride_o_fp8_bs + stride_o_fp8_d * offs_d
        tl.store(o_fp8_ptrs, kv_fp8)

        o_scale_ptr = O_scale + (store_start_index + i) * stride_o_scale_bs
        tl.store(o_scale_ptr, kv_scale)

    return


@torch.no_grad()
def extract_indexer_ks(
    I_buffer: torch.Tensor,
    b_seq_len: torch.Tensor,
    b_req_idx: torch.Tensor,
    req_to_token_indexs: torch.Tensor,
    out_token_num: int,
    max_kv_seq_len: int,
    mtp_step: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Gather cached indexer fp8 keys and fp32 scales into dense tensors.

    I_buffer is the 132-byte-per-row tail of the kv buffer (128 fp8 bytes +
    4 scale bytes, see destindex_copy_indexer_ks).  With MTP, batch rows come
    in groups of (mtp_step + 1) and only the last row of each group is used.
    Returns (O_fp8, O_scale) packed request-contiguously.
    """
    head_dim = 128

    assert I_buffer.dtype == torch.uint8, f"Expected I_buffer dtype=uint8, got {I_buffer.dtype}"
    assert I_buffer.shape[2] == 132, f"Expected I_buffer last dim=132, got {I_buffer.shape[2]}"
    in_fp8 = I_buffer[:, :, 0:128].view(dtype=torch.float8_e4m3fn)
    in_fp8_scale = I_buffer[:, :, 128:132].view(dtype=torch.float32)

    # Allocate output tensors
    O_fp8 = torch.empty((out_token_num // (mtp_step + 1), head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device)
    O_scale = torch.empty((out_token_num // (mtp_step + 1), 1), dtype=torch.float32, device=I_buffer.device)

    assert b_seq_len.shape[0] % (mtp_step + 1) == 0
    # Cap the token-axis grid at 256 programs; the kernel strides over tokens.
    grid = (b_seq_len.shape[0] // (mtp_step + 1), min(256, max_kv_seq_len))
    num_warps = 1

    _fwd_kernel_extract_indexer_ks[grid](
        in_fp8,
        stride_in_fp8_bs=in_fp8.stride(0),
        stride_in_fp8_h=in_fp8.stride(1),
        stride_in_fp8_d=in_fp8.stride(2),
        in_fp8_scale=in_fp8_scale,
        stride_in_scale_bs=in_fp8_scale.stride(0),
        stride_in_scale_h=in_fp8_scale.stride(1),
        stride_in_scale_d=in_fp8_scale.stride(2),
        req_to_token_indexs=req_to_token_indexs,
        stride_req_to_token_m=req_to_token_indexs.stride(0),
        stride_req_to_token_n=req_to_token_indexs.stride(1),
        b_seq_len=b_seq_len,
        b_req_idx=b_req_idx,
        O_fp8=O_fp8,
        stride_o_fp8_bs=O_fp8.stride(0),
        stride_o_fp8_d=O_fp8.stride(1),
        O_scale=O_scale,
        stride_o_scale_bs=O_scale.stride(0),
        stride_o_scale_d=O_scale.stride(1),
        mtp_step=mtp_step,
        BLOCK_DMODEL=head_dim,
        BLOCK_SEQ_LEN=triton.next_power_of_2(b_seq_len.shape[0] // (mtp_step + 1)),
        num_warps=num_warps,
        num_stages=1,
    )

    return O_fp8, O_scale
diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py
new file mode 100644
index 0000000000..1c1f72b7d7
--- /dev/null
+++ b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py
import triton
import triton.language as tl
import torch


@triton.jit
def _fp8_paged_mqa_logits_kernel(
    Q_ptr,
    KV_ptr,
    KVScale_ptr,
    Weights_ptr,
    MemIndex_ptr,
    CuSeqlenKs_ptr,
    CuSeqlenKe_ptr,
    Output_ptr,
    seq_len,
    seq_len_kv,
    num_heads,
    head_dim,
    stride_q_seq,
    stride_q_head,
    stride_q_dim,
    stride_kv_pool,
    stride_kv_dim,
    stride_w_seq,
    stride_w_head,
    stride_o_seq,
    stride_o_kv,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # Indexer logits: for each (query, kv) pair, sum over heads of
    # weight[h] * relu(q[h] . k), with k dequantized by its per-row scale.
    # Pairs outside [ks, ke) are masked to -inf.

    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    start_m = pid_m * BLOCK_SIZE_M
    start_n = pid_n * BLOCK_SIZE_N

    offs_m = start_m + tl.arange(0, BLOCK_SIZE_M)
    offs_n = start_n + tl.arange(0, BLOCK_SIZE_N)

    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    mask_m = offs_m < seq_len
    mask_n = offs_n < seq_len_kv

    mem_indices = tl.load(MemIndex_ptr + offs_n, mask=mask_n, other=0)

    # NOTE(review): scale lookup assumes KVScale is densely indexed by the
    # same mem_indices (stride 1) — confirm against the caller's layout.
    scales = tl.load(KVScale_ptr + mem_indices, mask=mask_n, other=1.0)

    for h in range(num_heads):
        weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, mask=mask_m, other=0.0)
        score = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for d_block in range(tl.cdiv(head_dim, BLOCK_SIZE_D)):
            d_start = d_block * BLOCK_SIZE_D
            offs_d = d_start + tl.arange(0, BLOCK_SIZE_D)
            mask_d = offs_d < head_dim

            q_ptrs = Q_ptr + offs_m[:, None] * stride_q_seq + h * stride_q_head + offs_d[None, :] *
stride_q_dim + mask_q = (offs_m[:, None] < seq_len) & mask_d[None, :] + q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32) + + k_ptrs = KV_ptr + mem_indices[:, None] * stride_kv_pool + offs_d[None, :] * stride_kv_dim + mask_k = mask_n[:, None] & mask_d[None, :] + k = tl.load(k_ptrs, mask=mask_k, other=0.0).to(tl.float32) + + k = k * scales[:, None] + + score += tl.dot(q, tl.trans(k)) + score = tl.maximum(score, 0.0) + logits += score * weights[:, None] + + mask_ks = tl.load(CuSeqlenKs_ptr + offs_m, mask=mask_m, other=0) + mask_ke = tl.load(CuSeqlenKe_ptr + offs_m, mask=mask_m, other=seq_len_kv) + + mask_lo = offs_n[None, :] >= mask_ks[:, None] + mask_hi = offs_n[None, :] < mask_ke[:, None] + mask_valid = mask_lo & mask_hi & mask_m[:, None] & mask_n[None, :] + + logits = tl.where(mask_valid, logits, float("-inf")) + + # Store output + out_ptrs = Output_ptr + offs_m[:, None] * stride_o_seq + offs_n[None, :] * stride_o_kv + mask_out = (offs_m[:, None] < seq_len) & (offs_n[None, :] < seq_len_kv) + tl.store(out_ptrs, logits, mask=mask_out) + + +def fp8_paged_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + kv_scale: torch.Tensor, + weights: torch.Tensor, + mem_index: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + out: torch.Tensor = None, +) -> torch.Tensor: + seq_len, num_heads, head_dim = q.shape + seq_len_kv = mem_index.shape[0] + + if out is None: + output = torch.empty((seq_len, seq_len_kv), device=q.device, dtype=torch.float32) + else: + output = out + + BLOCK_SIZE_M = 16 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_D = 128 + + grid = (triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len_kv, BLOCK_SIZE_N)) + + _fp8_paged_mqa_logits_kernel[grid]( + q, + kv, + kv_scale, + weights, + mem_index, + cu_seqlen_ks, + cu_seqlen_ke, + output, + seq_len, + seq_len_kv, + num_heads, + head_dim, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + weights.stride(0), + weights.stride(1), + output.stride(0), + 
output.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + + return output diff --git a/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py new file mode 100644 index 0000000000..8079864133 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py @@ -0,0 +1,104 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py + +import triton +import triton.language as tl +import torch +from typing import Tuple + +fp8_min = -448.0 +fp8_max = 448.0 +fp8_dtype = torch.float8_e4m3fn + + +@triton.jit +def _per_token_group_quant_mla_deep_gemm_masked_fp8( + y_ptr, + y_q_ptr, + y_s_ptr, + masked_m_ptr, + group_size, + y_stride_b, + y_stride_t, + y_q_stride_b, + y_q_stride_t, + y_s_stride_b, + y_s_stride_g, + eps, + fp8_min, + fp8_max, + NUM_GROUP: tl.constexpr, + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor for deep_gemm grouped_gemm_masked. + This function converts the tensor values into float8 values. 
+ y and y_q: (b, t, k) + y_s: (b, k//group_size, t) + """ + t_id = tl.program_id(0) + b_id = tl.program_id(1) + + y_ptr += b_id * y_stride_b + t_id * y_stride_t + y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t + y_s_ptr += b_id * y_s_stride_b + t_id + + if t_id == 0: + tl.store(masked_m_ptr + b_id, tl.num_programs(0)) + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + for gid in range(NUM_GROUP): + y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(tl.float32) + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask) + tl.store(y_s_ptr + gid * y_s_stride_g, y_s) + + +def per_token_group_quant_mla_deep_gemm_masked_fp8( + x: torch.Tensor, + group_size: int = 128, + eps: float = 1e-12, + dtype: torch.dtype = fp8_dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function quantizes input values to float8 values with per-token-group-quantization + for deep_gemm grouped_gemm_masked and specialized for mla absorbed case. 
+ """ + assert x.dim() == 3, "`x` is not a 3d-tensor" + + b, m, k = x.shape + aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel + num_tiles_k = k // group_size + assert num_tiles_k * group_size == k, f"k % {group_size} must be zero" + + x_q = x.new_empty((b, aligned_m, k), dtype=dtype) + x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32) + masked_m = x.new_empty((b,), dtype=torch.int32) + + BLOCK_SIZE = triton.next_power_of_2(group_size) + grid = (m, b) + + _per_token_group_quant_mla_deep_gemm_masked_fp8[grid]( + x, + x_q, + x_s, + masked_m, + group_size, + x.stride(0), + x.stride(1), + x_q.stride(0), + x_q.stride(1), + x_s.stride(0), + x_s.stride(1), + eps, + -fp8_max, + fp8_max, + num_tiles_k, + BLOCK_SIZE, + ) + + return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m diff --git a/lightllm/models/deepseek3_2/triton_kernel/topk_index_to_mem_index.py b/lightllm/models/deepseek3_2/triton_kernel/topk_index_to_mem_index.py new file mode 100644 index 0000000000..12786c6619 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/topk_index_to_mem_index.py @@ -0,0 +1,47 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _trans_topk_index_to_mem_index( + topk_index, + topk_index_stride_b, + topk_index_stride_k, + ragged_mem_index, + topk_mem_index, + topk_mem_index_stride_b, + topk_mem_index_stride_k, + BLOCK_DMODEL: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_DMODEL) + topk_index_ptrs = topk_index + cur_index * topk_index_stride_b + offs_d * topk_index_stride_k + topk_indices = tl.load(topk_index_ptrs) + + dest_mem_index = ragged_mem_index + topk_indices + mem_index = tl.load(dest_mem_index, mask=topk_indices != -1, other=-1) + tl.store(topk_mem_index + cur_index * topk_mem_index_stride_b + offs_d * topk_mem_index_stride_k, mem_index) + + +@torch.no_grad() +def trans_topk_index_to_mem_index(topk_index: torch.Tensor, ragged_mem_index: torch.Tensor): 
+ assert topk_index.shape[1] == 2048, f"Expected topk_index shape[1]=2048, got {topk_index.shape[1]}" + + grid = (topk_index.shape[0],) + + topk_mem_index = torch.empty_like(topk_index) + + _trans_topk_index_to_mem_index[grid]( + topk_index=topk_index, + topk_index_stride_b=topk_index.stride(0), + topk_index_stride_k=topk_index.stride(1), + ragged_mem_index=ragged_mem_index, + topk_mem_index=topk_mem_index, + topk_mem_index_stride_b=topk_mem_index.stride(0), + topk_mem_index_stride_k=topk_mem_index.stride(1), + BLOCK_DMODEL=2048, + num_warps=8, + ) + return topk_mem_index diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 204bf83435..be24043538 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -128,7 +128,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--tool_call_parser", type=str, - choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2"], + choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "deepseekv32", "glm47", "kimi_k2"], default=None, help="tool call parser type", ) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 11e24612b0..b06d8068ce 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -339,6 +339,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: finish_reason = None + has_emitted_tool_calls = False from .req_id_generator import convert_sub_id_to_group_id prompt_tokens = 0 @@ -359,7 +360,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: if reasoning_text: choice_data = ChatCompletionStreamResponseChoice( index=0, - delta=DeltaMessage(reasoning_content=reasoning_text), + delta=DeltaMessage(role="assistant", reasoning_content=reasoning_text), finish_reason=None, ) chunk = ChatCompletionStreamResponse( @@ -368,7 +369,7 @@ 
async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" if request.tool_choice != "none" and request.tools: if index not in parser_dict: @@ -387,8 +388,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]: if normal_text: choice_data = ChatCompletionStreamResponseChoice( index=0, - delta=DeltaMessage(content=normal_text), - finish_reason=finish_reason if finish_reason else None, + delta=DeltaMessage(role="assistant", content=normal_text), + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=group_request_id, @@ -396,11 +397,12 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" # 2) if we found calls, we output them as separate chunk(s) history_tool_calls_cnt = _get_history_tool_calls_cnt(request) for call_item in calls: + has_emitted_tool_calls = True # transform call_item -> FunctionResponse + ToolCall if finish_reason == "stop": latest_delta_len = 0 @@ -437,7 +439,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choice_data = ChatCompletionStreamResponseChoice( index=0, delta=DeltaMessage(role="assistant", tool_calls=[tool_call]), - finish_reason="tool_calls", + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=group_request_id, @@ -445,24 +447,36 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" else: - group_request_id = convert_sub_id_to_group_id(sub_req_id) - delta_message = DeltaMessage(role="assistant", content=delta) - if finish_status.is_finished(): - finish_reason = finish_status.get_finish_reason() - stream_choice = 
ChatCompletionStreamResponseChoice( - index=0, delta=delta_message, finish_reason=finish_reason - ) + stream_choice = ChatCompletionStreamResponseChoice(index=0, delta=delta_message, finish_reason=None) stream_resp = ChatCompletionStreamResponse( id=group_request_id, created=created_time, model=request.model, choices=[stream_choice], ) - yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") - # Additional usage chunk + yield f"data: {stream_resp.model_dump_json(exclude_none=True)}\n\n" + + # Determine final finish_reason: override to "tool_calls" if tool calls were emitted + if has_emitted_tool_calls and finish_reason == "stop": + finish_reason = "tool_calls" + + # Final empty chunk containing only finish_reason (and role) + if finish_reason is not None: + final_choice = ChatCompletionStreamResponseChoice( + index=0, + delta=DeltaMessage(), + finish_reason=finish_reason, + ) + final_chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + model=request.model, + choices=[final_choice], + ) + yield f"data: {final_chunk.model_dump_json(exclude_none=True)}\n\n" if request.stream_options and request.stream_options.include_usage: usage = UsageInfo( @@ -477,7 +491,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, usage=usage, ) - yield f"data: {usage_chunk.model_dump_json()}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) @@ -679,7 +693,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, usage=usage, ) - yield f"data: {usage_chunk.model_dump_json()}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) diff 
--git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 9214715b1d..3a8fddf744 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -29,7 +29,15 @@ logger = logging.getLogger(__name__) -TOOLS_TAG_LIST = ["<|plugin|>", "", "<|python_tag|>", "[TOOL_CALLS]", "<|tool▁calls▁begin|>"] +TOOLS_TAG_LIST = [ + "<|plugin|>", + "", + "<|python_tag|>", + "[TOOL_CALLS]", + "<|tool▁calls▁begin|>", + "<|DSML|function_calls>", +] class ToolCallItem(BaseModel): @@ -1443,6 +1451,272 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami return StreamingParseResult(normal_text="", calls=calls) +class DeepSeekV32Detector(BaseFormatDetector): + """ + Detector for DeepSeek V3.2 model function call format using DSML + (DeepSeek Markup Language). + + Format Structure: + ``` + <|DSML|function_calls> + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">Hangzhou + <|DSML|parameter name="date" string="true">2024-01-16 + + + ``` + + Key Components: + - Function Calls Block: `<|DSML|function_calls>` ... `` + - Individual Invocation: `<|DSML|invoke name="func">` ... 
`` + - Parameters: `<|DSML|parameter name="key" string="true|false">value` + - string="true": value is plain text (will be JSON-escaped) + - string="false": value is JSON (numbers, booleans, arrays, objects) + - Supports multiple parallel tool calls + + Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3.2 + """ + + def __init__(self): + super().__init__() + self.dsml_token = "|DSML|" + self.bot_token = f"<{self.dsml_token}function_calls>" + self.eot_token = f"" + self.invoke_start_prefix = f"<{self.dsml_token}invoke" + self.invoke_end_token = f"" + self.param_end_token = f"" + + # Regex for complete invoke extraction + _de = re.escape(self.dsml_token) + self.invoke_regex = re.compile( + rf'<{_de}invoke\s+name="([^"]+)"\s*>(.*?)', + re.DOTALL, + ) + # Regex for parameter extraction + self.param_regex = re.compile( + rf'<{_de}parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>(.*?)', + re.DOTALL, + ) + # Regex for partial invoke (name known, body still streaming) + self.partial_invoke_regex = re.compile( + rf'<{_de}invoke\s+name="([^"]+)"\s*>(.*)', + re.DOTALL, + ) + + self._last_arguments = "" + self._accumulated_params: List[tuple] = [] + self._in_function_calls = False # Track if we're inside a function_calls block + + def has_tool_call(self, text: str) -> bool: + return self.bot_token in text + + def _dsml_params_to_json(self, params: List[tuple]) -> str: + """Convert DSML parameter tuples (name, is_str, value) to a JSON arguments string.""" + args = {} + for name, is_str, value in params: + if is_str == "true": + args[name] = value + else: + try: + args[name] = json.loads(value) + except (json.JSONDecodeError, ValueError): + args[name] = value + return json.dumps(args, ensure_ascii=False) + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """One-time parsing for DSML format tool calls.""" + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + if self.bot_token not in text: + 
return StreamingParseResult(normal_text=normal_text, calls=[]) + + tool_indices = self._get_tool_indices(tools) + calls = [] + + invoke_matches = self.invoke_regex.findall(text) + for func_name, invoke_body in invoke_matches: + if func_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {func_name}") + continue + + param_matches = self.param_regex.findall(invoke_body) + args_json = self._dsml_params_to_json(param_matches) + + calls.append( + ToolCallItem( + tool_index=tool_indices[func_name], + name=func_name, + parameters=args_json, + ) + ) + + return StreamingParseResult(normal_text=normal_text, calls=calls) + + def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult: + """Streaming incremental parsing for DSML format tool calls.""" + self._buffer += new_text + current_text = self._buffer + + # Check if we're inside a function_calls block or starting one + has_tool = self.has_tool_call(current_text) or self._in_function_calls + + if not has_tool: + partial_len = self._ends_with_partial_token(current_text, self.bot_token) + if partial_len: + return StreamingParseResult() + + self._buffer = "" + for e_token in [self.eot_token, self.invoke_end_token]: + if e_token in new_text: + new_text = new_text.replace(e_token, "") + return StreamingParseResult(normal_text=new_text) + + # Mark that we're inside a function_calls block + if self.has_tool_call(current_text): + self._in_function_calls = True + + # Check if function_calls block has ended + if self.eot_token in current_text: + self._in_function_calls = False + + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + calls: List[ToolCallItem] = [] + + try: + # Try to find complete invoke blocks first + complete_invoke_match = self.invoke_regex.search(current_text) + if complete_invoke_match: + func_name = complete_invoke_match.group(1) + invoke_body = complete_invoke_match.group(2) + + if 
self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._accumulated_params = [] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + param_matches = self.param_regex.findall(invoke_body) + args_json = self._dsml_params_to_json(param_matches) + + if not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + + # Send complete arguments (or remaining diff) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = args_json[sent:] + if argument_diff: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff + + try: + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": json.loads(args_json), + } + except json.JSONDecodeError: + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + + # Remove processed invoke from buffer + invoke_end_pos = current_text.find(self.invoke_end_token, complete_invoke_match.start()) + if invoke_end_pos != -1: + self._buffer = current_text[invoke_end_pos + len(self.invoke_end_token) :] + else: + self._buffer = current_text[complete_invoke_match.end() :] + + self.current_tool_id += 1 + self._last_arguments = "" + self.current_tool_name_sent = False + self._accumulated_params = [] + self.streamed_args_for_tool.append("") + + return StreamingParseResult(normal_text="", calls=calls) + + # Partial invoke: name is known but parameters are still streaming + partial_match = self.partial_invoke_regex.search(current_text) + if partial_match: + func_name = partial_match.group(1) + 
partial_body = partial_match.group(2) + + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._accumulated_params = [] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + if not self.current_tool_name_sent: + if func_name in self._tool_indices: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + else: + # Stream arguments as complete parameters are parsed + param_matches = self.param_regex.findall(partial_body) + if param_matches and len(param_matches) > len(self._accumulated_params): + self._accumulated_params = param_matches + current_args_json = self._dsml_params_to_json(param_matches) + + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = current_args_json[sent:] + + if argument_diff: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff + + try: + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = json.loads(current_args_json) + except json.JSONDecodeError: + pass + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in DeepSeekV32 parse_streaming_increment: {e}") + return StreamingParseResult(normal_text="", calls=calls) + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. 
@@ -1455,6 +1729,7 @@ class FunctionCallParser: ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = { "deepseekv3": DeepSeekV3Detector, "deepseekv31": DeepSeekV31Detector, + "deepseekv32": DeepSeekV32Detector, "glm47": Glm47Detector, "kimi_k2": KimiK2Detector, "llama3": Llama32Detector, diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 09bc938f23..2800bf0f6b 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -78,6 +78,17 @@ def get_tokenizer( model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) model_type = model_cfg.get("model_type", "") + # DeepSeek-V3.2 custom tokenizer mode: wraps the HF tokenizer with + # a Python-based apply_chat_template that uses encoding_dsv32.py. + if model_type == "deepseek_v32": + from ..models.deepseek3_2.model import DeepSeekV32Tokenizer + + hf_tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs + ) + logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") + return DeepSeekV32Tokenizer(hf_tokenizer) + if model_cfg["architectures"][0] == "TarsierForConditionalGeneration": from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor diff --git a/lightllm/utils/backend_validator.py b/lightllm/utils/backend_validator.py index 0e2f9c962e..6c5fe90309 100644 --- a/lightllm/utils/backend_validator.py +++ b/lightllm/utils/backend_validator.py @@ -187,6 +187,47 @@ def _validate_sdpa(): return True, None +def _validate_flashmla_sparse(): + """Validate flashmla_sparse (NSA) with ground truth.""" + try: + from sgl_kernel.flash_mla import flash_mla_sparse_fwd + + # need sgl-kernel version >= 0.3.21 and torch >= 2.9.1 + except Exception as e: + return False, f"sgl_kernel.flash_mla import failed: {type(e).__name__}: {e}" + + batch, heads, seq, dim = 1, 64, 128, 512 + 64 + dtype = torch.bfloat16 + device = "cuda" + + q = torch.randn(batch * seq, heads, dim, dtype=dtype, 
device=device) + kv = torch.zeros(batch * seq, 1, dim, dtype=dtype, device=device) + + index_topk = 128 + topk_indices = torch.zeros(batch * seq, index_topk, dtype=torch.int32, device=device) + for i in range(seq): + topk_indices[i, :] = torch.arange(index_topk, dtype=torch.int32, device=device) + + topk_indices = topk_indices.view(batch * seq, 1, index_topk) + + softmax_scale = 1.0 / (dim ** 0.5) + kv_lora_rank = dim + + try: + mla_out, _, _ = flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=topk_indices, + sm_scale=softmax_scale, + d_v=kv_lora_rank, + ) + torch.cuda.synchronize() + except Exception as e: + return False, f"flash_mla_sparse_fwd run failed: {type(e).__name__}: {e}" + + return True, None + + def _run_in_subprocess(backend_name, pipe): """Run validation in subprocess with suppressed output.""" import sys @@ -207,6 +248,8 @@ def _run_in_subprocess(backend_name, pipe): success, err = _validate_flashinfer() elif backend_name == "triton": success, err = _validate_triton() + elif backend_name == "flashmla_sparse": + success, err = _validate_flashmla_sparse() else: success, err = False, f"Unknown backend: {backend_name}" pipe.send((success, err)) diff --git a/requirements.txt b/requirements.txt index 25cdab955d..5b0b201ae3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -59,7 +59,7 @@ six==1.16.0 sniffio==1.3.0 sortedcontainers==2.4.0 toolz==0.12.0 -torch==2.8.0 +torch==2.9.1 tqdm==4.65.0 transformers==4.57.1 tokenizers==0.22.1 @@ -80,16 +80,15 @@ frozendict==2.4.6 atomics==1.0.3 easydict==1.13 hypercorn==0.18.0 -flashinfer-python==0.2.4 -sgl-kernel==0.3.7.post1 +flashinfer-python==0.6.3 +sgl-kernel==0.3.21 httpx==0.28.1 librosa==0.11.0 cuda_bindings==12.9.0 orjson==3.11.2 setproctitle==1.3.6 -xformers==0.0.32.post1 xxhash==3.6.0 -torchvision==0.23.0 +torchvision==0.24.1 interegular==0.3.3 partial_json_parser==0.2.1.1.post6 websockets==15.0.1 diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh index 
7ceb1658c8..271176a571 100644 --- a/test/acc/test_deepseekr1_mtp_ep.sh +++ b/test/acc/test_deepseekr1_mtp_ep.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --enable_ep_moe --model_dir /mtc/sufubao/DeepSeek-V3.2 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_deepseekv32_ep.sh b/test/acc/test_deepseekv32_ep.sh new file mode 100644 index 0000000000..a926478bc5 --- /dev/null +++ b/test/acc/test_deepseekv32_ep.sh @@ -0,0 +1,4 @@ +LOADWORKER=14 python -m lightllm.server.api_server --model_dir /mtc/sufubao/DeepSeek-V3.2 --tp 8 --graph_max_batch_size 32 --tool_call_parser deepseekv32 --mem_fraction 0.8 --reasoning_parser deepseek-v3 --dp 8 --enable_ep_moe + + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-V3.2", "base_url":"http://localhost:8000/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/chat_template/tool_chat_template_deepseekv32.jinjia b/test/chat_template/tool_chat_template_deepseekv32.jinjia index b6d239dce7..7bb0fc375f 100644 --- 
a/test/chat_template/tool_chat_template_deepseekv32.jinjia +++ b/test/chat_template/tool_chat_template_deepseekv32.jinjia @@ -1,101 +1,202 @@ -{% if not add_generation_prompt is defined %} - {% set add_generation_prompt = false %} -{% endif %} -{% if not thinking is defined %} - {% set thinking = false %} -{% endif %} -{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false, is_only_sys=false, is_prefix=false) %} -{%- for message in messages %} - {%- if message['role'] == 'system' %} - {%- if ns.is_first_sp %} - {% set ns.system_prompt = ns.system_prompt + message['content'] %} - {% set ns.is_first_sp = false %} - {%- else %} - {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} - {%- endif %} - {% set ns.is_only_sys = true %} - {%- endif %} -{%- endfor %} - -{% if tools is defined and tools is not none %} - {% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %} - {% for tool in tools %} - {% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %} - {% endfor %} - {% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %} - {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} -{% endif %} - -{{ bos_token }}{{ ns.system_prompt }} -{%- for message in messages %} - {%- if message['role'] == 'user' %} - {%- set ns.is_tool = false -%} - {%- set ns.is_first = false -%} - {%- set 
ns.is_last_user = true -%} - {{'<|User|>' + message['content']}} - {%- endif %} - {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} - {%- if ns.is_last_user or ns.is_only_sys %} - {{'<|Assistant|>'}} - {%- endif %} - {%- set ns.is_last_user = false -%} - {%- set ns.is_first = false %} - {%- set ns.is_tool = false -%} - {%- for tool in message['tool_calls'] %} - {%- set formatted_args = tool['function']['arguments'] if tool['function']['arguments'] is string else tool['function']['arguments']|tojson %} - {%- if not ns.is_first %} - {%- if message['content'] is none %} - {{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- else %} - {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- endif %} - {%- set ns.is_first = true -%} - {%- else %} - {{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- endif %} - {%- endfor %} - {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} - {%- endif %} - {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %} - {%- if ns.is_last_user %} - {{'<|Assistant|>'}} - {%- if message['prefix'] is defined and message['prefix'] and thinking %} - {{''}} - {%- else %} - {{''}} - {%- endif %} - {%- endif %} - {%- if message['prefix'] is defined and message['prefix'] %} - {%- set ns.is_prefix = true -%} - {%- endif %} - {%- set ns.is_last_user = false -%} - {%- if ns.is_tool %} - {{message['content'] + '<|end▁of▁sentence|>'}} - {%- set ns.is_tool = false -%} - {%- else %} - {%- set content = message['content'] -%} - {%- if '' in content %} - {%- set content = content.split('', 1)[1] -%} - {%- endif %} - {{content + '<|end▁of▁sentence|>'}} - {%- endif %} - {%- endif %} - {%- if 
message['role'] == 'tool' %} - {%- set ns.is_last_user = false -%} - {%- set ns.is_tool = true -%} - {{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} - {%- endif %} - {%- if message['role'] != 'system' %} - {% set ns.is_only_sys = false %} - {%- endif %} +{#- ============================================================================ + DeepSeek-V3.2 DSML Chat Template + Converted from encoding_dsv32.py encode_messages function. + Uses DSML (DeepSeek Markup Language) format for tool calls. + ============================================================================ -#} +{%- set bos_token = "<|begin▁of▁sentence|>" -%} +{%- set eos_token = "<|end▁of▁sentence|>" -%} +{%- set thinking_start_token = "" -%} +{%- set thinking_end_token = "" -%} +{%- set dsml_token = "|DSML|" -%} + +{%- set system_msg_template = "{content}" -%} +{%- set user_msg_template = "<|User|>{content}<|Assistant|>" -%} +{%- set assistant_msg_template = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>" -%} +{%- set thinking_template = "{reasoning_content}" -%} +{%- set tool_call_template = "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n{dsml_token}invoke>" -%} +{%- set tool_calls_template = "<{dsml_token}function_calls>\n{tool_calls}\n{dsml_token}function_calls>" -%} +{%- set tool_output_template = "\n{content}" -%} + +{%- set TOOLS_SYSTEM_TEMPLATE -%} +## Tools +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<{{ dsml_token }}function_calls>" block like the following as part of your reply to the user: + +<{{ dsml_token }}function_calls> +<{{ dsml_token }}invoke name="$FUNCTION_NAME"> +<{{ dsml_token }}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE{{ dsml_token }}parameter> +... +{{ dsml_token }}invoke> +<{{ dsml_token }}invoke name="$FUNCTION_NAME2"> +... 
+{{ dsml_token }}invoke> +{{ dsml_token }}function_calls> + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). + +Here are the functions available in JSONSchema format: +{tool_schemas} +{%- endset -%} + +{%- if thinking_mode is not defined -%} + {%- set thinking_mode = "thinking" -%} +{%- endif -%} +{%- if drop_thinking is not defined -%} + {%- set drop_thinking = true -%} +{%- endif -%} +{%- if add_default_bos_token is not defined -%} + {%- set add_default_bos_token = true -%} +{%- endif -%} + +{#- Macro: encode_arguments_to_dsml -#} +{%- macro encode_arguments_to_dsml(arguments) -%} + {%- set ns = namespace(P_dsml_strs=[]) -%} + {%- if arguments is mapping -%} + {%- for k, v in arguments.items() -%} + {%- if v is string -%} + {%- set is_str = "true" -%} + {%- set value = v -%} + {%- else -%} + {%- set is_str = "false" -%} + {%- set value = v | tojson -%} + {%- endif -%} + {%- set p_dsml_str = "<" ~ dsml_token ~ "parameter name=\"" ~ k ~ "\" string=\"" ~ is_str ~ "\">" ~ value ~ dsml_token ~ "parameter>" -%} + {%- set ns.P_dsml_strs = ns.P_dsml_strs + [p_dsml_str] -%} + {%- endfor -%} + {%- endif -%} + {{- ns.P_dsml_strs | join("\n") -}} +{%- endmacro -%} + +{#- Macro: render_tools -#} +{%- macro render_tools(tools) -%} + {%- set ns = namespace(tools_json=[]) -%} + {%- for tool in tools -%} + {%- if tool.function is defined -%} + {%- set ns.tools_json = ns.tools_json + [tool.function | tojson] -%} + {%- else -%} + {%- set ns.tools_json = ns.tools_json + [tool | tojson] -%} + {%- endif -%} + {%- endfor -%} + {{- TOOLS_SYSTEM_TEMPLATE | replace("{tool_schemas}", ns.tools_json | join("\n")) }} +{% endmacro -%} + +{#- Macro: find_last_user_index -#} +{%- macro find_last_user_index(messages) -%} + {%- set ns = namespace(last_user_index=-1) -%} + 
{%- for msg in messages -%} + {%- set role = msg.role if msg.role is defined else msg.get('role') -%} + {%- if role in ['user', 'developer'] -%} + {%- set ns.last_user_index = loop.index0 -%} + {%- endif -%} + {%- endfor -%} + {{- ns.last_user_index -}} +{%- endmacro -%} + +{#- Macro: render_tool_calls_content -#} +{%- macro render_tool_calls_content(tool_calls) -%} + {%- set ns = namespace(formatted_calls=[]) -%} + {%- for tool_call in tool_calls -%} + {%- if tool_call.function is defined -%} + {%- set name = tool_call.function.name -%} + {%- set arguments = tool_call.function.arguments -%} + {%- else -%} + {%- set name = tool_call.name -%} + {%- set arguments = tool_call.arguments -%} + {%- endif -%} + {%- if arguments is string -%} + {%- set arguments = arguments | fromjson -%} + {%- endif -%} + {%- set formatted_call = "<" ~ dsml_token ~ "invoke name=\"" ~ name ~ "\">\n" ~ encode_arguments_to_dsml(arguments) ~ "\n" ~ dsml_token ~ "invoke>" -%} + {%- set ns.formatted_calls = ns.formatted_calls + [formatted_call] -%} + {%- endfor -%} + {{- "<" ~ dsml_token ~ "function_calls>\n" ~ ns.formatted_calls | join("\n") ~ "\n" ~ dsml_token ~ "function_calls>" -}} +{%- endmacro -%} + +{#- Macro: render_message -#} +{%- macro render_message(index, messages, thinking_mode) -%} + {%- set msg = messages[index] -%} + {%- set last_user_idx = find_last_user_index(messages) | int -%} + {%- set role = msg.role if msg.role is defined else msg.get('role') -%} + {%- set content = msg.content if msg.content is defined else (msg.get('content', '') or '') -%} + {%- set msg_tools = msg.tools if msg.tools is defined else msg.get('tools', []) -%} + {%- set tool_calls = msg.tool_calls if msg.tool_calls is defined else msg.get('tool_calls', []) -%} + {%- set reasoning_content = msg.reasoning_content if msg.reasoning_content is defined else (msg.get('reasoning_content', '') or '') -%} + + {%- if role == 'system' -%} + {{- content or '' -}} + {%- if msg_tools -%} + {{- "\n\n" ~ 
render_tools(msg_tools) -}} + {%- endif -%} + + {%- elif role == 'user' -%} + {{- "<|User|>" ~ content ~ "<|Assistant|>" -}} + {%- if index == last_user_idx and thinking_mode == "thinking" -%} + {{- thinking_start_token -}} + {%- else -%} + {{- thinking_end_token -}} + {%- endif -%} + + {%- elif role == 'tool' -%} + {%- set ns = namespace(prev_assistant_idx=-1) -%} + {%- for i in range(index - 1, -1, -1) -%} + {%- set check_role = messages[i].role if messages[i].role is defined else messages[i].get('role') -%} + {%- if check_role != 'tool' and ns.prev_assistant_idx == -1 -%} + {%- set ns.prev_assistant_idx = i -%} + {%- endif -%} + {%- endfor -%} + {%- set tool_call_order = index - ns.prev_assistant_idx -%} + {%- set assistant_msg = messages[ns.prev_assistant_idx] -%} + {%- set assistant_tool_calls = assistant_msg.tool_calls if assistant_msg.tool_calls is defined else assistant_msg.get('tool_calls', []) -%} + {%- if tool_call_order == 1 -%} + {{- "\n\n" -}} + {%- endif -%} + {{- "\n" ~ content -}} + {%- if tool_call_order == (assistant_tool_calls | length) -%} + {{- "\n" -}} + {%- if index >= last_user_idx and thinking_mode == "thinking" -%} + {{- "\n\n" ~ thinking_start_token -}} + {%- else -%} + {{- "\n\n" ~ thinking_end_token -}} + {%- endif -%} + {%- endif -%} + + {%- elif role == 'assistant' -%} + {%- set ns = namespace(thinking_part="", tool_calls_content="") -%} + {%- if tool_calls -%} + {%- set ns.tool_calls_content = "\n\n" ~ render_tool_calls_content(tool_calls) -%} + {%- endif -%} + {%- set summary_content = content or "" -%} + {%- if thinking_mode == "thinking" and index > last_user_idx -%} + {%- set ns.thinking_part = reasoning_content ~ thinking_end_token -%} + {%- endif -%} + {{- ns.thinking_part ~ summary_content ~ ns.tool_calls_content ~ "<|end▁of▁sentence|>" -}} + {%- endif -%} +{%- endmacro -%} + +{#- Main template body -#} +{%- set full_messages = messages -%} + +{#- Handle tools in top-level (OpenAI format) -#} +{%- if tools is defined and 
tools is not none -%} + {%- set ns_sys = namespace(has_system=false, sys_idx=-1) -%} + {%- for msg in full_messages -%} + {%- set role = msg.role if msg.role is defined else msg.get('role') -%} + {%- if role == 'system' and not ns_sys.has_system -%} + {%- set ns_sys.has_system = true -%} + {%- set ns_sys.sys_idx = loop.index0 -%} + {%- endif -%} + {%- endfor -%} +{%- endif -%} + +{%- if add_default_bos_token -%} + {{- bos_token -}} +{%- endif -%} + +{#- If tools defined at top level but no system message has them, prepend tools info -#} +{%- if tools is defined and tools is not none -%} + {{- render_tools(tools) -}} +{%- endif -%} + +{%- for msg in full_messages -%} + {{- render_message(loop.index0, full_messages, thinking_mode) -}} {%- endfor -%} -{% if add_generation_prompt and not ns.is_tool%} - {% if ns.is_last_user or ns.is_only_sys or not ns.is_prefix %} - {{'<|Assistant|>'}} - {%- if not thinking %} - {{''}} - {%- else %} - {{''}} - {%- endif %} - {% endif %} -{% endif %} diff --git a/test/test_api/test_gsmk.py b/test/test_api/test_gsmk.py new file mode 100644 index 0000000000..2d9ead65b8 --- /dev/null +++ b/test/test_api/test_gsmk.py @@ -0,0 +1,265 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py +import argparse +import ast +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import numpy as np +import requests +from tqdm import tqdm + +INVALID = -9999999 + +SYSTEM_PROMPT_TARGET_LEN = 18192 + + +def generate_system_prompt(): + """Generate a system prompt of approximately 8192 characters.""" + base = ( + "You are a highly capable math assistant. Your task is to solve grade school math problems step by step. " + "Show your reasoning clearly and provide the final numerical answer. " + "Break down each problem into smaller steps and verify your calculations. " + "Always end your answer with the format: #### . 
"
+    )
+    # Repeat base text to reach target length
+    repeats = SYSTEM_PROMPT_TARGET_LEN // len(base) + 1
+    prompt = (base * repeats)[:SYSTEM_PROMPT_TARGET_LEN]
+    return prompt
+
+
+def read_jsonl(filename: str):
+    """Read a JSONL file, skipping lines that start with '#'."""
+    with open(filename) as fin:
+        for line in fin:
+            if line.startswith("#"):
+                continue
+            yield json.loads(line)
+
+
+def dump_state_text(filename: str, states: list, mode: str = "w"):
+    """Dump program state in a text file."""
+    with open(filename, mode) as fout:
+        for i, s in enumerate(states):
+            if isinstance(s, str):
+                fout.write(f"==== {i} ====\n{s}\n")
+            else:
+                fout.write(f"==== {i} ====\n{str(s)}\n")
+
+
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
+
+    # Check if the cache file already exists
+    if os.path.exists(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")  # report the actual destination path
+
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
+
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
+
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as file, tqdm(
+        desc="Downloading",
+        total=total_size,
+        unit="iB",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            size = file.write(chunk)
+            bar.update(size)
+
+    return filename
+
+
+def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
+    """Call LightLLM API for text generation."""
+    assert url is not None
+
+    data = {
+        "inputs": prompt,
+        "parameters": {
+            "temperature": temperature,
+            "max_new_tokens": max_tokens,
+            "stop_sequences": stop,
+            "repetition_penalty": 1.0,
+            "top_p": 1.0,
+            "top_k": 1,
+        },
+    }
+    res = 
requests.post(url, json=data)
+    assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}"
+
+    response_json = res.json()
+    if "generated_text" not in response_json:
+        raise ValueError(f"Invalid API response format. Expected 'generated_text' key, got: {response_json.keys()}")
+    if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0:
+        raise ValueError(
+            "Invalid API response format. 'generated_text' should be a non-empty list, "
+            f"got: {response_json['generated_text']}"
+        )
+
+    pred = response_json["generated_text"][0]
+    return pred
+
+
+def get_one_example(lines, i, include_answer):
+    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
+    if include_answer:
+        ret += " " + lines[i]["answer"]
+    return ret
+
+
+def get_few_shot_examples(lines, k):
+    ret = ""
+    for i in range(k):
+        ret += get_one_example(lines, i, True) + "\n\n"
+    return ret
+
+
+def get_answer_value(answer_str):
+    answer_str = answer_str.replace(",", "")
+    # First try to find the answer after "####" marker (GSM8K format)
+    match = re.search(r"####\s*(-?\d+)", answer_str)
+    if match:
+        try:
+            return ast.literal_eval(match.group(1))
+        except (SyntaxError, ValueError):  # literal_eval may raise either for malformed literals
+            pass
+    # Fallback: find all numbers and take the last one
+    numbers = re.findall(r"\d+", answer_str)
+    if len(numbers) < 1:
+        return INVALID
+    try:
+        return ast.literal_eval(numbers[-1])
+    except (SyntaxError, ValueError):  # e.g. leading zeros raise SyntaxError; be defensive on both
+        return INVALID
+
+
+def parse_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--parallel", type=int, default=256)
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--num-shots", type=int, default=5)
+    parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--result-file", type=str, default="result.jsonl")
+    parser.add_argument("--data-path", type=str, 
default="test.jsonl") + parser.add_argument( + "--system-prompt", action="store_true", help="Prepend an 8192-character system prompt to each request" + ) + return parser.parse_args() + + +def main(args): + # LightLLM API URL + url = f"{args.host}:{args.port}/generate" + + # Read data + url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url_data) + lines = list(read_jsonl(filename)) + + # Construct prompts + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) + + system_prefix = "" + if args.system_prompt: + system_prefix = generate_system_prompt() + "\n\n" + print(f"System prompt enabled: {len(system_prefix)} characters") + + # Ensure we have enough samples and avoid data leakage + # Test questions should start after few-shot examples + max_available = len(lines) - num_shots + if num_questions > max_available: + print( + "Warning: Requested {} questions, but only {} available after reserving {} for few-shot. 
" + "Using {} questions.".format(num_questions, max_available, num_shots, max_available) + ) + num_questions = max_available + + questions = [] + labels = [] + for i in range(num_shots, num_shots + num_questions): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(label != INVALID for label in labels) + + states = [None] * len(labels) + + # Run requests using thread pool + def get_one_answer(i): + answer = call_generate_lightllm( + prompt=system_prefix + few_shot_examples + questions[i], + temperature=0, + max_tokens=1024, + stop=["Question", "Assistant:", "<|separator|>", "Human:", "\n\nQuestion"], + url=url, + ) + states[i] = answer + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + + latency = time.perf_counter() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + + # Dump results + dump_state_text("tmp_output_lightllm.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": "lightllm", + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py 
b/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py
new file mode 100644
index 0000000000..6567479c25
--- /dev/null
+++ b/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py
@@ -0,0 +1,87 @@
+import torch
+import pytest
+from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke, gen_same_req_mark
+
+
+def test_gen_nsa_ks_ke_basic():
+    """Test basic functionality of gen_nsa_ks_ke with simple inputs."""
+    # Setup test data
+    b_seq_len = torch.tensor(
+        [
+            10,
+        ],
+        dtype=torch.int32,
+        device="cuda",
+    )
+    b_q_seq_len = torch.tensor(
+        [
+            5,
+        ],
+        dtype=torch.int32,
+        device="cuda",
+    )
+    b_req_idx = torch.tensor(
+        [
+            1,
+        ],
+        dtype=torch.int32,
+        device="cuda",
+    )
+    q_token_num = b_q_seq_len.sum().item()
+
+    req_to_token_index = torch.arange(0, 1000).cuda().view(100, -1)
+    ragged_mem_index = torch.empty_like(req_to_token_index.view(-1))
+
+    ks, ke, lengths = gen_nsa_ks_ke(
+        b_seq_len, b_q_seq_len, b_req_idx, req_to_token_index, q_token_num, ragged_mem_index
+    )
+
+    assert torch.equal(ks, torch.tensor([0, 0, 0, 0, 0], dtype=torch.int32, device="cuda"))
+    assert torch.equal(ke, torch.tensor([5, 6, 7, 8, 9], dtype=torch.int32, device="cuda"))
+    assert torch.equal(lengths, ke - ks + 1)
+    assert torch.equal(
+        ragged_mem_index[0:10], torch.arange(10, 20, device="cuda")  # int64 to match empty_like's dtype; torch.equal is dtype-sensitive
+    )
+    return
+
+
+def test_gen_nsa_ks_ke_batch():
+    b_seq_len = torch.tensor([10, 11], dtype=torch.int32, device="cuda")
+    b_q_seq_len = torch.tensor([1, 1], dtype=torch.int32, device="cuda")
+    b_req_idx = torch.tensor([1, 1], dtype=torch.int32, device="cuda")
+    q_token_num = b_q_seq_len.sum().item()
+
+    req_to_token_index = torch.arange(0, 1000).cuda().view(10, -1)
+    ragged_mem_index = torch.empty_like(req_to_token_index.view(-1))
+
+    ks, ke, lengths = gen_nsa_ks_ke(
+        b_seq_len, b_q_seq_len, b_req_idx, req_to_token_index, q_token_num, ragged_mem_index
+    )
+
+    assert torch.equal(
+        ks,
+        
torch.tensor(
+            [
+                0,
+                0,
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        ),
+    )
+    assert torch.equal(ke, torch.tensor([9, 10], dtype=torch.int32, device="cuda"))
+    assert torch.equal(lengths, ke - ks + 1)
+    assert torch.equal(ragged_mem_index[0:11], torch.arange(100, 100 + 11).cuda())
+
+
+def test_gen_same_req_mark():
+    b_req_idx = torch.tensor([0, 0, 1, 1, 1, 2], dtype=torch.int32, device="cuda")
+    expected_same_req_mark = torch.tensor([0, 2, 0, 0, 3, 1], dtype=torch.int32, device="cuda")
+
+    same_req_mark = gen_same_req_mark(b_req_idx)
+
+    assert torch.equal(same_req_mark, expected_same_req_mark)
+
+
+if __name__ == "__main__":
+    pytest.main()
diff --git a/unit_tests/models/deepseek3_2/triton_kernel/test_topk_index_to_mem_index.py b/unit_tests/models/deepseek3_2/triton_kernel/test_topk_index_to_mem_index.py
new file mode 100644
index 0000000000..5e3aaf38d0
--- /dev/null
+++ b/unit_tests/models/deepseek3_2/triton_kernel/test_topk_index_to_mem_index.py
@@ -0,0 +1,24 @@
+import torch
+import pytest
+from lightllm.models.deepseek3_2.triton_kernel.topk_index_to_mem_index import trans_topk_index_to_mem_index
+
+
+def test_trans_topk_index_to_mem_index():
+    """Test trans_topk_index_to_mem_index converts topk indices to memory indices correctly."""
+    batch_size = 1
+    topk = 2048
+
+    # All topk slots hold valid indices here; the -1 (padding) path is not exercised by this case
+    topk_index = torch.zeros((batch_size, topk), dtype=torch.int32, device="cuda")
+    topk_index[:, 0:2048] = torch.arange(0, 2048, dtype=torch.int32, device="cuda")
+
+    # Create ragged_mem_index lookup table
+    ragged_mem_index = torch.arange(0, 2048, dtype=torch.int32, device="cuda") + 10
+
+    topk_mem_index = trans_topk_index_to_mem_index(topk_index, ragged_mem_index)
+
+    assert torch.equal(topk_mem_index, (torch.arange(0, 2048, dtype=torch.int32, device="cuda") + 10).view(1, -1))
+
+
+if __name__ == "__main__":
+    pytest.main()