From c24e04aa52b077dd885d9cdfb061f94c39a47f32 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 5 Nov 2025 03:55:14 +0000 Subject: [PATCH 01/58] support deepseek v3.2 --- .../layer_weights/transformer_layer_weight.py | 11 +- lightllm/models/deepseek3_2/infer_struct.py | 9 ++ .../layer_infer/nsa_indexer_layer_inder.py | 142 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 127 ++++++++++++++++ .../layer_weights/nsa_indexer_layer_weight.py | 49 ++++++ .../layer_weights/transformer_layer_weight.py | 16 ++ lightllm/models/deepseek3_2/mem_manager.py | 47 ++++++ lightllm/models/deepseek3_2/model.py | 38 +++++ .../deepseek3_2/triton_kernel/__init__.py | 0 .../deepseek3_2/triton_kernel/act_quant.py | 137 +++++++++++++++++ .../triton_kernel/token_group_quant.py | 103 +++++++++++++ 11 files changed, 678 insertions(+), 1 deletion(-) create mode 100644 lightllm/models/deepseek3_2/infer_struct.py create mode 100644 lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py create mode 100644 lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py create mode 100644 lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/deepseek3_2/mem_manager.py create mode 100644 lightllm/models/deepseek3_2/model.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/__init__.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/act_quant.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 86a887a259..0f4d6b13ae 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -36,11 +36,20 @@ def load_hf_weights(self, weights): 
""" for attr_name in dir(self): attr = getattr(self, attr_name, None) - if isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2: + if isinstance(attr, TransformerLayerWeight): + attr.load_hf_weights(weights) + elif isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2: with self.lock: attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) + def verify_load(self): + for attr_name in dir(self): + attr = getattr(self, attr_name, None) + if isinstance(attr, TransformerLayerWeight): + attr.verify_load() + super().verify_load() + def get_quant_method(self, name): return self.quant_cfg.get_quant_method(self.layer_num_, name) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py new file mode 100644 index 0000000000..6e5e766b25 --- /dev/null +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -0,0 +1,9 @@ +import os +import torch +import numpy as np +import torch.distributed as dist +from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo + + +class Deepseek3_2FlashAttentionInferStateInfo(Deepseek2FlashAttentionStateInfo): + pass \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py new file mode 100644 index 0000000000..a3891f0f3e --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -0,0 +1,142 @@ +from sgl_kernel import fast_topk_transform_fused +import deep_gemm +import torch +import torch.nn.functional as F + +from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer +from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from 
lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant + + +class NSAIndexerInfer(BaseLayerInfer): + def __init__(self, layer_idx, network_config, mode=[]): + super().__init__() + self.layer_idx_ = layer_idx + self.network_config_ = network_config + self.mode = mode + self.index_topk = network_config["index_topk"] + self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = 1 + self.tp_v_head_num_ = 1 + self.qk_nope_head_dim = network_config["qk_nope_head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.index_head_dim = network_config["index_head_dim"] + self.eps = network_config["rms_norm_eps"] + self.block_size = network_config["quantization_config"]["weight_block_size"][0] + self.scale_fmt = network_config["quantization_config"]["scale_fmt"] + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.index_n_heads = network_config["index_n_heads"] + self.index_n_heads_scale = self.index_n_heads ** -0.5 + + self.q_lora = None + self.hidden_states = None + return + + def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False): + seq_len_kv = kv.shape[0] + + if cost_only: + start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) + end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) + count_ones_per_row = (end - start).clamp(min=0) + return count_ones_per_row.sum() + + k = kv + q = q.float() + k = k.float() + + mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum('mhd,nd->hmn', q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float('-inf')) + + cost = mask.sum() + return logits, cost + + def get_indices(self, 
infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: + assert self.hidden_states is not None + assert self.q_lora is not None + + q, k = self._get_q_k_bf16(infer_state, layer_weight) + q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) + + weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale + weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + + logits = fp8_paged_mqa_logits_torch( + q_fp8, k_fp8, weights, + infer_state.lengths, + infer_state.page_table, + infer_state.max_model_len + ) + + return fast_topk_transform_fused( + score=logits, + lengths=infer_state.lengths, + page_table_size_1=infer_state.page_table, + cu_seqlens_q=infer_state.b1_cu_q_seq_len, + topk=self.index_topk + ) + + @staticmethod + def _rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from sgl_kernel import hadamard_transform + + hidden_size = x.size(-1) + assert ( + hidden_size & (hidden_size - 1) + ) == 0, "Hidden size must be a power of 2 for Hadamard transform." 
+ return hadamard_transform(x, scale=hidden_size**-0.5) + + def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight): + q = layer_weight.wq_b_proj_.mm(self.q_lora).view(-1, self.index_n_heads, self.index_head_dim) + self.q_lora = None + + k = layer_weight.wk_proj_.mm(self.hidden_states) + self.hidden_states = None + k = F.layer_norm( + k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps + ).type_as(k) + + rotary_emb_fwd( + q[:, :, : self.qk_rope_head_dim], + k[:, None, : self.qk_rope_head_dim], + infer_state.position_cos, + infer_state.position_sin, + ) + + q = self._rotate_activation(q) + k = self._rotate_activation(k) + return q, k + + +# TODO +def fp8_paged_mqa_logits_torch(q: torch.Tensor, kv_cache: torch.Tensor, + weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, + max_model_len: int): + batch_size, next_n, heads, dim = q.size() + num_block, block_size, _, dim = kv_cache.size() + logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32) + context_lens = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens[i] + q_offsets = torch.arange(context_len - next_n, context_len, device='cuda') + weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous() + for block_rk in range((context_len + block_size - 1) // block_size): + block_idx = block_tables[i][block_rk] + qx, kx = q[i], kv_cache[block_idx] + k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device='cuda') + mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) + s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf')) + s = torch.relu(s) * weight_slice[..., None] + s = s.sum(dim=0) + logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 
1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) + return logits \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..6db8c14e8a --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -0,0 +1,127 @@ +from functools import partial +from typing import override + +import torch +from sgl_kernel.flash_mla import flash_mla_sparse_fwd +from sgl_kernel.flash_attn import flash_attn_with_kvcache + +from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer +from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 +from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd + + +class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + self.index_topk = network_config["index_topk"] + super().__init__(layer_num, network_config, mode) + + self.indexer = NSAIndexerInfer( + layer_idx=self.layer_num_, + network_config=self.network_config_, + mode=mode + ) + return + + @override + def _get_qkv( + self, + input: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionInferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + + if self.q_lora_rank is None: + q = layer_weight.q_weight_.mm(input) + cache_kv = 
layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + else: + q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + + self.indexer.hidden_states = input + self.indexer.q_lora = q + + q = layer_weight.q_b_proj_.mm(q) + cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + rmsnorm_forward( + cache_kv[:, :, : self.kv_lora_rank], + weight=layer_weight.kv_a_layernorm_.weight, + eps=self.eps_, + out=cache_kv[:, :, : self.kv_lora_rank], + ) + + rotary_emb_fwd( + q_rope, + cache_kv[:, :, self.kv_lora_rank :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv + + @override + def _bind_attention(self): + self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._context_attention_flashmla_kernel_with_indexer, self) + self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._token_attention_flashmla_kernel_with_indexer, self) + pass + + def _context_attention_flashmla_kernel_with_indexer( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek3_2FlashInferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + q_all = torch.cat([q_nope, q_rope], dim=-1) + topk_indices = self.indexer.get_indices( + infer_state, + layer_weight.indexer_layer_weight, + ) + mla_out, _, _ = flash_mla_sparse_fwd( + q=q_all, + kv=infer_state.mem_manager.kv_buffer[self.layer_num_], + indices=topk_indices.unsqueeze(1), + 
sm_scale=self.softmax_scale, + d_v=self.kv_lora_rank, + ) + return mla_out + + def _token_attention_flashmla_kernel_with_indexer( + self, q, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) + kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) + topk_indices = self.indexer.get_indices( + infer_state, + layer_weight.indexer_layer_weight, + ) + o = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=topk_indices, + cache_seqlens=infer_state.b_att_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=infer_state.max_q_seq_len, + softmax_scale=self.softmax_scale, + causal=True, + softcap=0.0, + return_softmax_lse=False, + num_splits=0, # TODO enable_deterministic_inference + ) diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py new file mode 100644 index 0000000000..47e0bfdac5 --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -0,0 +1,49 @@ +from typing_extensions import override + +import torch + +from lightllm.common.basemodel.layer_weights.transformer_layer_weight import TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, NormWeight + + +class NSAIndexerWeight(TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + @override + def 
_init_weight(self): + prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" + + self.wq_b_proj_ = ROWMMWeight( + weight_name=f"{prefix}.wq_b.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="wq_b", + tp_rank=0, + tp_world_size=1, + ) + self.wk_proj_ = ROWMMWeight( + weight_name=f"{prefix}.wk.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="wk", + tp_rank=0, + tp_world_size=1, + ) + self.k_norm_ = NormWeight( + f"{prefix}.k_norm.weight", + torch.float32, + bias_name=f"{prefix}.k_norm.bias" + ) + self.weights_proj_ = ROWMMWeight( + weight_name=f"{prefix}.weights_proj.weight", + data_type=self.data_type_, + quant_cfg=None, + layer_num=self.layer_num_, + name="weights_proj", + tp_rank=0, + tp_world_size=1, + ) diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..2a03e1d6a1 --- /dev/null +++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py @@ -0,0 +1,16 @@ +from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight + + +class Deepseek3_2TransformerLayerWeight(Deepseek2TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + self.index_topk = network_config["index_topk"] + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + self.indexer_layer_weight = NSAIndexerWeight( + layer_num=layer_num, + data_type=data_type, + network_config=network_config, + mode=mode, + quant_cfg=quant_cfg + ) + return diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py new file mode 100644 index 0000000000..0aa0a0bdbd --- /dev/null +++ 
b/lightllm/models/deepseek3_2/mem_manager.py @@ -0,0 +1,47 @@ +from typing_extensions import override +import torch + +from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager + + +class Deepseek3_2MemoryManager(Deepseek2MemoryManager): + def __init__( + self, + size, + dtype, + head_num, + head_dim, + layer_num, + index_head_dim, + index_quant_block_size, + k_cache_dtype=torch.float8_e4m3fn, + k_scale_dtype=torch.float32, + always_copy=False, + mem_fraction=0.9 + ): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + assert index_head_dim % index_quant_block_size == 0, "index_head_dim must be divisible by index_quant_block_size" + self.index_head_dim = index_head_dim + self.index_quant_block_size = index_quant_block_size + self.k_cache_dtype = k_cache_dtype + self.k_scale_dtype = k_scale_dtype + return + + @override + def get_cell_size(self): + index_k_cache_cell_size = self.index_head_dim * self.layer_num * torch._utils._element_size(self.k_cache_dtype) + index_k_scale_cell_size = (self.index_head_dim // self.index_quant_block_size) * self.layer_num * torch._utils._element_size(self.k_scale_dtype) + return super().get_cell_size() + index_k_cache_cell_size + index_k_scale_cell_size + + @override + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + super()._init_buffers(size, dtype, head_num, head_dim, layer_num) + self._init_indexer_k_cache_buffers() + return + + def _init_indexer_k_cache_buffers(self): + self.indexer_k_cache_buffers = torch.empty( + (self.layer_num, self.size + 1, self.index_head_dim), dtype=self.k_cache_dtype, device="cuda") + self.indexer_k_scale_buffers = torch.empty( + (self.layer_num, self.size + 1, self.index_head_dim // self.index_quant_block_size), dtype=self.k_scale_dtype, device="cuda") + return diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py new file mode 100644 index 0000000000..3a244c77f7 --- /dev/null +++ 
b/lightllm/models/deepseek3_2/model.py
@@ -0,0 +1,38 @@
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.deepseek2.model import Deepseek2TpPartModel
+from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight
+from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager
+from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo
+
+@ModelRegistry(["deepseek_v32"])
+class Deepseek3_2TpPartModel(Deepseek2TpPartModel):
+    # weight class
+    transformer_weight_class = Deepseek3_2TransformerLayerWeight
+
+    # infer class
+    transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer
+
+    # infer state class
+    infer_state_class = Deepseek3_2FlashAttentionInferStateInfo
+
+    def _init_mem_manager(self):
+        # mtp 模式下需要在mem manger上扩展draft model使用的layer
+        added_mtp_layer_num = 0
+        if get_env_start_args().mtp_mode == "deepseekv3_eagle":
+            added_mtp_layer_num += 1
+        elif get_env_start_args().mtp_mode == "deepseekv3_vanilla":
+            added_mtp_layer_num += get_env_start_args().mtp_step
+
+        self.mem_manager = Deepseek3_2MemoryManager(
+            self.max_total_token_num,
+            dtype=self.data_type,
+            head_num=1,
+            head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"],
+            layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num,
+            index_head_dim = self.config["index_head_dim"],
+            index_quant_block_size = self.config["index_quant_block_size"],
+            mem_fraction=self.mem_fraction,
+        )
+        return
\ No newline at end of file
diff --git a/lightllm/models/deepseek3_2/triton_kernel/__init__.py b/lightllm/models/deepseek3_2/triton_kernel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/deepseek3_2/triton_kernel/act_quant.py 
b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py new file mode 100644 index 0000000000..a4ecd0f518 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/act_quant.py @@ -0,0 +1,137 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/ce6b17c0f94e6bf53633c8f324176a891e67fa7f/python/sglang/srt/layers/attention/nsa/triton_kernel.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +# Triton implementation +@triton.jit +def _act_quant_kernel( + X_ptr, + Y_ptr, + S_ptr, + M, + N, + group_size: tl.constexpr, + round_scale: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Triton kernel for activation quantization. + + Each block processes BLOCK_M rows and group_size columns. + """ + # Get block IDs + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # FP8 constants + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1.0 / fp8_max + + # Calculate row and column offsets + row_start = pid_m * BLOCK_M + col_start = pid_n * group_size + + # Create offset arrays + rows = row_start + tl.arange(0, BLOCK_M) + cols = col_start + tl.arange(0, BLOCK_N) + + # Mask for valid rows and columns + row_mask = rows < M + col_mask = cols < N + mask = row_mask[:, None] & col_mask[None, :] + + # Load input data + x_ptrs = X_ptr + rows[:, None] * N + cols[None, :] + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Compute absolute max along columns (group_size dimension) for each row + x_abs = tl.abs(x) + amax = tl.max(x_abs, axis=1) # Shape: (BLOCK_M,) + + # Clamp amax to avoid division by zero + amax = tl.maximum(amax, 1e-4) + + # Compute scale + if round_scale: + # Fast round scale using bit manipulation approximation + # This is a simplified version - the exact bit manipulation is harder in Triton + # Using log2 + ceil + pow2 as approximation + log_val = tl.log2(amax * fp8_max_inv) + log_ceil = tl.ceil(log_val) + scale = tl.exp2(log_ceil) + else: + scale = amax * 
fp8_max_inv + + # Quantize: y = clamp(x / scale, fp8_min, fp8_max) + scale_broadcast = scale[:, None] + y = x / scale_broadcast + y = tl.minimum(tl.maximum(y, fp8_min), fp8_max) + + # Store quantized output + y_ptrs = Y_ptr + rows[:, None] * N + cols[None, :] + tl.store(y_ptrs, y, mask=mask) + + # Store scales + s_cols = pid_n + s_ptrs = S_ptr + rows * (N // group_size) + s_cols + s_mask = row_mask + tl.store(s_ptrs, scale, mask=s_mask) + + +def act_quant( + x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization with Triton. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. 
+ """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + + # Flatten all dims except last + N = x.size(-1) + x_flat = x.view(-1, N) + M = x_flat.size(0) + + # Allocate output tensors + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + y_flat = y.view(-1, N) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + s_flat = s.view(-1, N // block_size) + + # Launch kernel + BLOCK_M = 32 + BLOCK_N = block_size + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size)) + round_scale = scale_fmt is not None + + _act_quant_kernel[grid]( + x_flat, + y_flat, + s_flat, + M, + N, + group_size=block_size, + round_scale=round_scale, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_stages=0 if round_scale else 2, + ) + + return y, s diff --git a/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py new file mode 100644 index 0000000000..dbf5c51992 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py @@ -0,0 +1,103 @@ +import triton +import triton.language as tl +import torch +from typing import Tuple + +fp8_min = -448.0 +fp8_max = 448.0 +fp8_dtype = torch.float8_e4m3fn + +@triton.jit +def _per_token_group_quant_mla_deep_gemm_masked_fp8( + y_ptr, + y_q_ptr, + y_s_ptr, + masked_m_ptr, + group_size, + y_stride_b, + y_stride_t, + y_q_stride_b, + y_q_stride_t, + y_s_stride_b, + y_s_stride_g, + eps, + fp8_min, + fp8_max, + NUM_GROUP: tl.constexpr, + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor for deep_gemm grouped_gemm_masked. + This function converts the tensor values into float8 values. 
+ y and y_q: (b, t, k) + y_s: (b, k//group_size, t) + """ + t_id = tl.program_id(0) + b_id = tl.program_id(1) + + y_ptr += b_id * y_stride_b + t_id * y_stride_t + y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t + y_s_ptr += b_id * y_s_stride_b + t_id + + if t_id == 0: + tl.store(masked_m_ptr + b_id, tl.num_programs(0)) + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + for gid in range(NUM_GROUP): + y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to( + tl.float32 + ) + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask) + tl.store(y_s_ptr + gid * y_s_stride_g, y_s) + + +def per_token_group_quant_mla_deep_gemm_masked_fp8( + x: torch.Tensor, + group_size: int = 128, + eps: float = 1e-12, + dtype: torch.dtype = fp8_dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + This function quantizes input values to float8 values with per-token-group-quantization + for deep_gemm grouped_gemm_masked and specialized for mla absorbed case. 
+ """ + assert x.dim() == 3, "`x` is not a 3d-tensor" + + b, m, k = x.shape + aligned_m = (m + 255) // 256 * 256 # 256 is the max block_m of the gemm kernel + num_tiles_k = k // group_size + assert num_tiles_k * group_size == k, f"k % {group_size} must be zero" + + x_q = x.new_empty((b, aligned_m, k), dtype=dtype) + x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32) + masked_m = x.new_empty((b,), dtype=torch.int32) + + BLOCK_SIZE = triton.next_power_of_2(group_size) + grid = (m, b) + + _per_token_group_quant_mla_deep_gemm_masked_fp8[grid]( + x, + x_q, + x_s, + masked_m, + group_size, + x.stride(0), + x.stride(1), + x_q.stride(0), + x_q.stride(1), + x_s.stride(0), + x_s.stride(1), + eps, + -fp8_max, + fp8_max, + num_tiles_k, + BLOCK_SIZE, + ) + + return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m \ No newline at end of file From 8926ac917ed4d98ddf79cd9db94178f51c74cfcb Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 5 Nov 2025 06:22:57 +0000 Subject: [PATCH 02/58] fix --- lightllm/models/deepseek3_2/infer_struct.py | 7 ++- .../layer_infer/nsa_indexer_layer_inder.py | 9 ++- .../layer_infer/transformer_layer_infer.py | 2 +- lightllm/models/deepseek3_2/mem_manager.py | 63 +++++++++++++------ lightllm/models/deepseek3_2/model.py | 2 - 5 files changed, 57 insertions(+), 26 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 6e5e766b25..20f8b7e8d6 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -3,7 +3,12 @@ import numpy as np import torch.distributed as dist from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2IndexerPagedMemoryManager, Deepseek3_2MemoryManager class Deepseek3_2FlashAttentionInferStateInfo(Deepseek2FlashAttentionStateInfo): - pass \ No newline at end of file + + def __init__(self): + 
super().__init__() + assert isinstance(self.req_manager.mem_manager, Deepseek3_2MemoryManager) + self.indexer_paged_mem_manager : Deepseek3_2IndexerPagedMemoryManager = self.req_manager.mem_manager.indexer_paged_mem_manager diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index a3891f0f3e..100df16f94 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -28,7 +28,7 @@ def __init__(self, layer_idx, network_config, mode=[]): self.scale_fmt = network_config["quantization_config"]["scale_fmt"] self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) self.index_n_heads = network_config["index_n_heads"] - self.index_n_heads_scale = self.index_n_heads ** -0.5 + self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale self.q_lora = None self.hidden_states = None @@ -67,8 +67,13 @@ def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, laye q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) + # write + # infer_state.mem_manager. 
+ + # read + weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale - weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale + weights = weights.unsqueeze(-1) * q_scale logits = fp8_paged_mqa_logits_torch( q_fp8, k_fp8, weights, diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 6db8c14e8a..9f503e9bda 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -76,7 +76,7 @@ def _context_attention_flashmla_kernel_with_indexer( self, q: torch.Tensor, kv, - infer_state: Deepseek3_2FlashInferStateInfo, + infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index 0aa0a0bdbd..f2613aacc2 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -1,9 +1,37 @@ from typing_extensions import override import torch +from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager +from lightllm.utils.log_utils import init_logger +logger = init_logger(__name__) +class Deepseek3_2IndexerPagedMemoryManager: + def __init__(self, page_size): + self.page_size = page_size + return + + def set_size(self, size): + self.physics_size = size + self.num_pages = size // self.page_size + return + + def _init_buffers(self): + self.k_cache_buffer = torch.empty( + (self.page_size, 128), dtype=torch.float8_e4m3fn, device="cuda") + self.k_scale_buffer = torch.empty( + (self.page_size, 1), dtype=torch.float64, device="cuda") + return + + def alloc_paged_index(self, last_index: int, need_size): + pass + + def get_cell_size(self): + # Use for deepseek v3.2 exp only, 
128 for k_cache(128 torch.float8_e4m3fn), 4 for scale(1 torch.float64) + return 128 + 4 + + class Deepseek3_2MemoryManager(Deepseek2MemoryManager): def __init__( self, @@ -12,36 +40,31 @@ def __init__( head_num, head_dim, layer_num, - index_head_dim, - index_quant_block_size, - k_cache_dtype=torch.float8_e4m3fn, - k_scale_dtype=torch.float32, always_copy=False, - mem_fraction=0.9 + mem_fraction=0.9, + page_size=64 ): + self.page_size = page_size + self.indexer_paged_mem_manager = Deepseek3_2IndexerPagedMemoryManager(page_size) super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - assert index_head_dim % index_quant_block_size == 0, "index_head_dim must be divisible by index_quant_block_size" - self.index_head_dim = index_head_dim - self.index_quant_block_size = index_quant_block_size - self.k_cache_dtype = k_cache_dtype - self.k_scale_dtype = k_scale_dtype + self.indexer_paged_mem_manager.set_size(self.size) return @override def get_cell_size(self): - index_k_cache_cell_size = self.index_head_dim * self.layer_num * torch._utils._element_size(self.k_cache_dtype) - index_k_scale_cell_size = (self.index_head_dim // self.index_quant_block_size) * self.layer_num * torch._utils._element_size(self.k_scale_dtype) - return super().get_cell_size() + index_k_cache_cell_size + index_k_scale_cell_size + return super().get_cell_size() + self.indexer_paged_mem_manager.get_cell_size() @override def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): super()._init_buffers(size, dtype, head_num, head_dim, layer_num) - self._init_indexer_k_cache_buffers() + self.indexer_paged_mem_manager._init_buffers() return - def _init_indexer_k_cache_buffers(self): - self.indexer_k_cache_buffers = torch.empty( - (self.layer_num, self.size + 1, self.index_head_dim), dtype=self.k_cache_dtype, device="cuda") - self.indexer_k_scale_buffers = torch.empty( - (self.layer_num, self.size + 1, self.index_head_dim // self.index_quant_block_size), 
dtype=self.k_scale_dtype, device="cuda") - return + @override + def profile_size(self, mem_fraction): + super().profile_size(mem_fraction) + if self.size % self.page_size != 0: + size_paged = (self.size // self.page_size + 1) * self.page_size + logger.warning(f"size {self.size} is not divisible by page_size {self.page_size}, will use paged_size {size_paged}") + self.size = size_paged + return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 3a244c77f7..5b3fc1f13d 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -31,8 +31,6 @@ def _init_mem_manager(self): head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, - index_head_dim = self.config["index_head_dim"], - index_quant_block_size = self.config["index_quant_block_size"], mem_fraction=self.mem_fraction, ) return \ No newline at end of file From d1956ccc14193650bd98aa8cdc87ea126436d23e Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 6 Nov 2025 10:40:46 +0000 Subject: [PATCH 03/58] fix --- .../deepseek2_mem_manager.py | 4 +- .../kv_cache_mem_manager/mem_manager.py | 17 ++- lightllm/models/__init__.py | 1 + lightllm/models/deepseek3_2/infer_struct.py | 26 +++- .../layer_infer/nsa_indexer_layer_inder.py | 136 ++++++++++------- .../layer_infer/transformer_layer_infer.py | 15 +- lightllm/models/deepseek3_2/mem_manager.py | 72 ++------- lightllm/models/deepseek3_2/model.py | 13 +- .../destindex_copy_indexer_ks.py | 137 ++++++++++++++++++ .../triton_kernel/fp8_mqa_logits.py | 0 10 files changed, 283 insertions(+), 138 deletions(-) create mode 100644 lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py 
b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 3d93e1b070..ad54b39353 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -15,8 +15,8 @@ class Deepseek2MemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): """ diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 1203cbdec7..2940d74e21 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -26,7 +26,7 @@ class MemoryManager: - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): self.size = size self.head_num = head_num self.head_dim = head_dim @@ -48,15 +48,16 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.can_use_mem_size = self.size - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name + if not is_sub_mem_manager: + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + from lightllm.utils.envs_utils import get_unique_server_name - rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - 
f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) + rank_in_node = get_current_rank_in_node() + self.shared_can_use_token_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 32ccbe8337..f2e29d4a88 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -18,6 +18,7 @@ from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel from lightllm.models.phi3.model import Phi3TpPartModel from lightllm.models.deepseek2.model import Deepseek2TpPartModel +from lightllm.models.deepseek3_2.model import Deepseek3_2TpPartModel from lightllm.models.glm4_moe_lite.model import Glm4MoeLiteTpPartModel from lightllm.models.internvl.model import ( InternVLLlamaTpPartModel, diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 20f8b7e8d6..bfdb53fd6a 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,14 +1,24 @@ -import os import torch -import numpy as np -import torch.distributed as dist from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2IndexerPagedMemoryManager, Deepseek3_2MemoryManager +class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): -class Deepseek3_2FlashAttentionInferStateInfo(Deepseek2FlashAttentionStateInfo): - def __init__(self): super().__init__() - assert isinstance(self.req_manager.mem_manager, Deepseek3_2MemoryManager) - self.indexer_paged_mem_manager : Deepseek3_2IndexerPagedMemoryManager = self.req_manager.mem_manager.indexer_paged_mem_manager + self.lengths = None + 
self.page_table_size_1 = None + self.ks = None + self.ke = None + return + + def init_some_extra_state(self, model, input_ids: torch.Tensor): + super().init_some_extra_state(model, input_ids) + # Ensure b_ready_cache_len is set for both prefill and decode modes + if self.is_prefill: + # b_ready_cache_len is already set in basemodel.py for prefill + pass + else: + # In decode mode, b_ready_cache_len should be b_seq_len - b_q_seq_len + # since b_q_seq_len represents the new tokens being processed + if self.b_ready_cache_len is None: + self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 100df16f94..1977c211e9 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -5,10 +5,12 @@ from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant - +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager +from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks +# from lightllm.models.deepseek3_2.triton_kernel.fp8_mqa_logits import fp8_mqa_logits class NSAIndexerInfer(BaseLayerInfer): def __init__(self, layer_idx, network_config, mode=[]): @@ -30,8 +32,6 @@ def __init__(self, layer_idx, network_config, mode=[]): self.index_n_heads = network_config["index_n_heads"] self.index_n_heads_scale = 
(self.index_n_heads ** -0.5) * self.softmax_scale - self.q_lora = None - self.hidden_states = None return def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, @@ -59,7 +59,7 @@ def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.T cost = mask.sum() return logits, cost - def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: + def get_indices(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: assert self.hidden_states is not None assert self.q_lora is not None @@ -67,29 +67,78 @@ def get_indices(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, laye q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) - # write - # infer_state.mem_manager. - - # read + self._copy_ks_to_mem_cache(k_fp8, k_scale, infer_state.mem_index, infer_state.mem_manager) weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale - weights = weights.unsqueeze(-1) * q_scale - - logits = fp8_paged_mqa_logits_torch( - q_fp8, k_fp8, weights, - infer_state.lengths, - infer_state.page_table, - infer_state.max_model_len + weights = weights.unsqueeze(-1) * q_scale + + ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + + k_fp8_list = [] + k_scale_list = [] + ks_list = [] + ke_list = [] + offset = 0 + for i in range(infer_state.batch_size): + q_len = infer_state.b_q_seq_len[i] + cache_len = infer_state.b_ready_cache_len[i] + mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] + k_fp8 = ks_buffer[mem_indexes, 0, :128].view(torch.float8_e4m3fn).contiguous() + k_scale = ks_buffer[mem_indexes, 0, 128:].view(torch.float32).contiguous() + ks = torch.full((q_len,), offset, dtype=torch.int32, device="cuda") + ke = ks + torch.arange(q_len, 
dtype=torch.int32, device="cuda") + 1 + k_fp8_list.append(k_fp8) + k_scale_list.append(k_scale) + ks_list.append(ks) + ke_list.append(ke) + offset += q_len + + k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) + k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) + kv_fp8 = (k_fp8, k_scale) + ks = torch.cat(ks_list, dim=0) + ke = torch.cat(ke_list, dim=0) + + logits = deep_gemm.fp8_mqa_logits( + q_fp8, + kv_fp8, + weights.squeeze(-1), + ks, + ke, + clean_logits=False, ) - return fast_topk_transform_fused( - score=logits, - lengths=infer_state.lengths, - page_table_size_1=infer_state.page_table, - cu_seqlens_q=infer_state.b1_cu_q_seq_len, - topk=self.index_topk - ) - + return self.get_topk(logits, infer_state) + + def get_topk(self, logits, infer_state: Deepseek3_2FlashAttentionStateInfo): + topk_indices_list = [] + offset = 0 + + for i in range(infer_state.batch_size): + q_len = infer_state.b_q_seq_len[i] + cache_len = infer_state.b_ready_cache_len[i] + end_pos = q_len + cache_len + # Slice logits for this batch (both query and sequence dimensions) + batch_logits = logits[offset:offset + q_len, :end_pos] + topk_indices = batch_logits.topk(min(self.index_topk, end_pos), dim=-1)[1] + mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] + indices = torch.full((q_len, self.index_topk), -1, dtype=torch.int32, device="cuda") + for j in range(q_len): + indices[j, :topk_indices[j].shape[0]] = mem_indexes[topk_indices[j]] + topk_indices_list.append(indices) + offset += q_len + + topk_indices_ = torch.cat(topk_indices_list, dim=0) + + return topk_indices_ + + + def get_k_float32_from_buffer(self, buffer: torch.Tensor): + k_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) + k_scale = buffer[:, :, 128:].view(torch.float32)[:, :, :1] + k_float32 = k_fp8.float() * k_scale + return k_float32 + @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 
@@ -101,12 +150,11 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor: ) == 0, "Hidden size must be a power of 2 for Hadamard transform." return hadamard_transform(x, scale=hidden_size**-0.5) - def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: NSAIndexerWeight): + def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): q = layer_weight.wq_b_proj_.mm(self.q_lora).view(-1, self.index_n_heads, self.index_head_dim) self.q_lora = None k = layer_weight.wk_proj_.mm(self.hidden_states) - self.hidden_states = None k = F.layer_norm( k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) @@ -122,26 +170,16 @@ def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionInferStateInfo, la k = self._rotate_activation(k) return q, k - -# TODO -def fp8_paged_mqa_logits_torch(q: torch.Tensor, kv_cache: torch.Tensor, - weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, - max_model_len: int): - batch_size, next_n, heads, dim = q.size() - num_block, block_size, _, dim = kv_cache.size() - logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32) - context_lens = context_lens.tolist() - for i in range(batch_size): - context_len = context_lens[i] - q_offsets = torch.arange(context_len - next_n, context_len, device='cuda') - weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous() - for block_rk in range((context_len + block_size - 1) // block_size): - block_idx = block_tables[i][block_rk] - qx, kx = q[i], kv_cache[block_idx] - k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device='cuda') - mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) - s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), 
float('-inf')) - s = torch.relu(s) * weight_slice[..., None] - s = s.sum(dim=0) - logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) - return logits \ No newline at end of file + def _copy_ks_to_mem_cache(self, k_fp8, k_scale, mem_index, mem_manager: Deepseek3_2MemoryManager): + # k_fp8 : [seq_len, 128] torch.fp8_e4m3 + # k_scale : [seq_len, 1] torch.float32 + # mem_index : [seq_len] torch.int32 + # buffer : [10000000, 1, 132] torch.uint8 + buffer = mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + destindex_copy_indexer_ks( + k_fp8.unsqueeze(1), # Add head dimension: [seq_len, 1, 128] + k_scale.unsqueeze(1), # Add head dimension: [seq_len, 1, 1] + mem_index, + buffer + ) + return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 9f503e9bda..076d3965cf 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -8,7 +8,7 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionInferStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd @@ -30,7 +30,7 @@ def __init__(self, layer_num, 
network_config, mode=[]): def _get_qkv( self, input: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionInferStateInfo, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) @@ -68,6 +68,7 @@ def _get_qkv( @override def _bind_attention(self): + super()._bind_attention() self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._context_attention_flashmla_kernel_with_indexer, self) self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._token_attention_flashmla_kernel_with_indexer, self) pass @@ -76,7 +77,7 @@ def _context_attention_flashmla_kernel_with_indexer( self, q: torch.Tensor, kv, - infer_state: Deepseek3_2FlashAttentionInferStateInfo, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: @@ -87,18 +88,19 @@ def _context_attention_flashmla_kernel_with_indexer( topk_indices = self.indexer.get_indices( infer_state, layer_weight.indexer_layer_weight, - ) + ).unsqueeze(1) + mla_out, _, _ = flash_mla_sparse_fwd( q=q_all, kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=topk_indices.unsqueeze(1), + indices=topk_indices, sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, ) return mla_out def _token_attention_flashmla_kernel_with_indexer( - self, q, infer_state: Deepseek3_2FlashAttentionInferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None + self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) @@ -125,3 +127,4 @@ def _token_attention_flashmla_kernel_with_indexer( return_softmax_lse=False, num_splits=0, # TODO enable_deterministic_inference ) + return o \ No newline at end of file diff --git 
a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index f2613aacc2..a70c762731 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -1,70 +1,22 @@ +from typing import List from typing_extensions import override import torch -from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.mem_manager import MemoryManager from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager -from lightllm.utils.log_utils import init_logger +from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.distributed.pynccl import PyNcclCommunicator -logger = init_logger(__name__) - -class Deepseek3_2IndexerPagedMemoryManager: - def __init__(self, page_size): - self.page_size = page_size - return - - def set_size(self, size): - self.physics_size = size - self.num_pages = size // self.page_size - return - - def _init_buffers(self): - self.k_cache_buffer = torch.empty( - (self.page_size, 128), dtype=torch.float8_e4m3fn, device="cuda") - self.k_scale_buffer = torch.empty( - (self.page_size, 1), dtype=torch.float64, device="cuda") - return - - def alloc_paged_index(self, last_index: int, need_size): - pass - - def get_cell_size(self): - # Use for deepseek v3.2 exp only, 128 for k_cache(128 torch.float8_e4m3fn), 4 for scale(1 torch.float64) - return 128 + 4 - - class Deepseek3_2MemoryManager(Deepseek2MemoryManager): - def __init__( - self, - size, - dtype, - head_num, - head_dim, - layer_num, - always_copy=False, - mem_fraction=0.9, - page_size=64 - ): - self.page_size = page_size - self.indexer_paged_mem_manager = Deepseek3_2IndexerPagedMemoryManager(page_size) - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - self.indexer_paged_mem_manager.set_size(self.size) + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9 ,is_sub_mem_manager=False): + 
super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) + self.indexer_ks_mem_manager = Deepseek2MemoryManager(self.size, torch.uint8, 1, 132, layer_num, is_sub_mem_manager=True) return @override def get_cell_size(self): - return super().get_cell_size() + self.indexer_paged_mem_manager.get_cell_size() - - @override - def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - super()._init_buffers(size, dtype, head_num, head_dim, layer_num) - self.indexer_paged_mem_manager._init_buffers() - return - - @override - def profile_size(self, mem_fraction): - super().profile_size(mem_fraction) - if self.size % self.page_size != 0: - size_paged = (self.size // self.page_size + 1) * self.page_size - logger.warning(f"size {self.size} is not divisible by page_size {self.page_size}, will use paged_size {size_paged}") - self.size = size_paged - return \ No newline at end of file + return super().get_cell_size() + 132 + +class Deepseek3_2FP8KVMemoryManager(Deepseek3_2MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): + super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction, is_sub_mem_manager) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 5b3fc1f13d..c4e56c3c15 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -3,9 +3,8 @@ from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashInferStateInfo - 
+from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class @@ -15,9 +14,13 @@ class Deepseek3_2TpPartModel(Deepseek2TpPartModel): transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer # infer state class - infer_state_class = Deepseek3_2FlashInferStateInfo + infer_state_class = Deepseek3_2FlashAttentionStateInfo def _init_mem_manager(self): + manager_class = Deepseek3_2MemoryManager + if "triton_fp8kv" in self.mode: + manager_class = Deepseek3_2FP8KVMemoryManager + # mtp 模式下需要在mem manger上扩展draft model使用的layer added_mtp_layer_num = 0 if get_env_start_args().mtp_mode == "deepseekv3_eagle": @@ -25,7 +28,7 @@ def _init_mem_manager(self): elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": added_mtp_layer_num += get_env_start_args().mtp_step - self.mem_manager = Deepseek3_2MemoryManager( + self.mem_manager = manager_class( self.max_total_token_num, dtype=self.data_type, head_num=1, diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py new file mode 100644 index 0000000000..a098795fb7 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -0,0 +1,137 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_indexer_ks( + k_fp8, + k_scale, + mem_index, + buffer_fp8, + buffer_scale, + stride_k_fp8_bs, + stride_k_fp8_h, + stride_k_fp8_d, + stride_k_scale_bs, + stride_k_scale_h, + stride_k_scale_d, + stride_buffer_fp8_bs, + stride_buffer_fp8_h, + stride_buffer_fp8_d, + stride_buffer_scale_bs, + stride_buffer_scale_h, + stride_buffer_scale_d, + head_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_HEAD: tl.constexpr, +): + cur_index = 
tl.program_id(0) + offs_h = tl.arange(0, BLOCK_HEAD) + offs_d = tl.arange(0, BLOCK_DMODEL) + + dest_index = tl.load(mem_index + cur_index).to(tl.int64) + + # Load k_fp8 data + k_fp8_ptrs = k_fp8 + cur_index * stride_k_fp8_bs + stride_k_fp8_h * offs_h[:, None] + stride_k_fp8_d * offs_d[None, :] + k_fp8_data = tl.load(k_fp8_ptrs, mask=offs_h[:, None] < head_num, other=0.0) + + # Load k_scale data + k_scale_ptrs = k_scale + cur_index * stride_k_scale_bs + stride_k_scale_h * offs_h[:, None] + stride_k_scale_d * tl.arange(0, 1)[None, :] + k_scale_data = tl.load(k_scale_ptrs, mask=offs_h[:, None] < head_num, other=0.0) + + # Store k_fp8 to buffer_fp8 + buffer_fp8_ptrs = buffer_fp8 + dest_index * stride_buffer_fp8_bs + stride_buffer_fp8_h * offs_h[:, None] + stride_buffer_fp8_d * offs_d[None, :] + tl.store(buffer_fp8_ptrs, k_fp8_data, mask=offs_h[:, None] < head_num) + + # Store k_scale to buffer_scale + buffer_scale_ptrs = buffer_scale + dest_index * stride_buffer_scale_bs + stride_buffer_scale_h * offs_h[:, None] + stride_buffer_scale_d * tl.arange(0, 1)[None, :] + tl.store(buffer_scale_ptrs, k_scale_data, mask=offs_h[:, None] < head_num) + + +@torch.no_grad() +def destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer): + seq_len = mem_index.shape[0] + head_num = k_fp8.shape[1] + k_fp8_dim = k_fp8.shape[2] # Should be 128 for float8 + k_scale_dim = k_scale.shape[2] # Should be 1 + + assert k_fp8.shape[1] == k_scale.shape[1] + assert k_fp8_dim == 128, f"k_fp8 dim should be 128, got {k_fp8_dim}" + assert k_scale_dim == 1, f"k_scale dim should be 1, got {k_scale_dim}" + assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" # 128 + 4 bytes + + # Reinterpret buffer as the appropriate types for storing + buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) + buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] + + BLOCK_HEAD = triton.next_power_of_2(head_num) + grid = (seq_len,) + num_warps = 1 + + 
_fwd_kernel_destindex_copy_indexer_ks[grid]( + k_fp8, + k_scale, + mem_index, + buffer_fp8, + buffer_scale, + k_fp8.stride(0), + k_fp8.stride(1), + k_fp8.stride(2), + k_scale.stride(0), + k_scale.stride(1), + k_scale.stride(2), + buffer_fp8.stride(0), + buffer_fp8.stride(1), + buffer_fp8.stride(2), + buffer_scale.stride(0), + buffer_scale.stride(1), + buffer_scale.stride(2), + head_num, + BLOCK_DMODEL=k_fp8_dim, + BLOCK_HEAD=BLOCK_HEAD, + num_warps=num_warps, + num_stages=1, + ) + return + + +def test(): + import torch.nn.functional as F + + # Test parameters similar to the usage in nsa_indexer_layer_inder.py + B, N_CTX, H, K_DIM = 4, 1024, 8, 128 # batch_size, seq_len, heads, k_dim + seq_len = 50 # number of tokens to copy + dtype_fp8 = torch.float8_e4m3fn + dtype_scale = torch.float32 + + # Create test data + k_fp8 = torch.randn((seq_len, H, K_DIM), dtype=dtype_fp8).cuda() + k_scale = torch.randn((seq_len, H, 1), dtype=dtype_scale).cuda() + mem_index = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() + + # Create buffer [total_tokens, heads, 132] + buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() + + # Call the function + destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer) + + # Verify results + for i in range(seq_len): + dest_idx = mem_index[i].item() + # Check k_fp8 part + stored_fp8 = buffer[dest_idx, :, :128].view(dtype_fp8) + expected_fp8 = k_fp8[i] + assert torch.allclose(stored_fp8, expected_fp8, atol=1e-6), f"FP8 mismatch at index {i}" + + # Check k_scale part + stored_scale = buffer[dest_idx, :, 128:].view(dtype_scale)[:, :1] + expected_scale = k_scale[i] + assert torch.allclose(stored_scale, expected_scale, atol=1e-6), f"Scale mismatch at index {i}" + + print("All tests passed!") + + +if __name__ == "__main__": + test() diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py new file mode 100644 index 0000000000..e69de29bb2 From 
4f8a74717d30daf9ad66dfd0b87fedc1f807333f Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 7 Nov 2025 09:16:40 +0000 Subject: [PATCH 04/58] fix --- lightllm/models/deepseek3_2/infer_struct.py | 2 + .../layer_infer/nsa_indexer_layer_inder.py | 17 +++---- .../layer_infer/transformer_layer_infer.py | 50 ++++++++----------- 3 files changed, 31 insertions(+), 38 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index bfdb53fd6a..4d77b5f6f3 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -9,6 +9,8 @@ def __init__(self): self.page_table_size_1 = None self.ks = None self.ke = None + + self.topk_indices = None return def init_some_extra_state(self, model, input_ids: torch.Tensor): diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 1977c211e9..3e5e1c266c 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -59,17 +59,16 @@ def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.T cost = mask.sum() return logits, cost - def get_indices(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: - assert self.hidden_states is not None - assert self.q_lora is not None + def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: - q, k = self._get_q_k_bf16(infer_state, layer_weight) + q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) self._copy_ks_to_mem_cache(k_fp8, k_scale, infer_state.mem_index, infer_state.mem_manager) 
- weights = layer_weight.weights_proj_.mm(self.hidden_states) * self.index_n_heads_scale + weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] @@ -150,11 +149,11 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor: ) == 0, "Hidden size must be a power of 2 for Hadamard transform." return hadamard_transform(x, scale=hidden_size**-0.5) - def _get_q_k_bf16(self, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): - q = layer_weight.wq_b_proj_.mm(self.q_lora).view(-1, self.index_n_heads, self.index_head_dim) - self.q_lora = None + def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): + q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) - k = layer_weight.wk_proj_.mm(self.hidden_states) + k = layer_weight.wk_proj_.mm(hidden_states) k = F.layer_norm( k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 076d3965cf..01514e96a2 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,5 +1,6 @@ from functools import partial from typing import override +from venv import logger import torch from sgl_kernel.flash_mla import flash_mla_sparse_fwd @@ -24,6 +25,7 @@ def __init__(self, layer_num, network_config, mode=[]): network_config=self.network_config_, mode=mode ) + self.topk_indices = None return @override @@ -35,20 +37,15 @@ def _get_qkv( ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - if self.q_lora_rank is None: - q = 
layer_weight.q_weight_.mm(input) - cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) - else: - q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 - ) - q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 + ) + q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) - self.indexer.hidden_states = input - self.indexer.q_lora = q + self.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) - q = layer_weight.q_b_proj_.mm(q) - cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) + q = layer_weight.q_b_proj_.mm(q) + cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) rmsnorm_forward( @@ -69,11 +66,11 @@ def _get_qkv( @override def _bind_attention(self): super()._bind_attention() - self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._context_attention_flashmla_kernel_with_indexer, self) - self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._token_attention_flashmla_kernel_with_indexer, self) + self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_context_attention_kernel, self) + self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_token_attention_kernel, self) pass - def _context_attention_flashmla_kernel_with_indexer( + def _nsa_context_attention_kernel( self, q: torch.Tensor, kv, @@ -85,21 +82,17 @@ def _context_attention_flashmla_kernel_with_indexer( q_nope, q_rope = torch.split(q, 
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) - topk_indices = self.indexer.get_indices( - infer_state, - layer_weight.indexer_layer_weight, - ).unsqueeze(1) mla_out, _, _ = flash_mla_sparse_fwd( q=q_all, kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=topk_indices, + indices=self.topk_indices, sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, ) return mla_out - def _token_attention_flashmla_kernel_with_indexer( + def _nsa_token_attention_kernel( self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] @@ -107,24 +100,23 @@ def _token_attention_flashmla_kernel_with_indexer( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - topk_indices = self.indexer.get_indices( - infer_state, - layer_weight.indexer_layer_weight, - ) - o = flash_attn_with_kvcache( + k_descale, v_descale = None, None + o_tensor = flash_attn_with_kvcache( q=q_rope, k_cache=k_rope, v_cache=kv_nope, qv=q_nope, - page_table=topk_indices, + page_table=self.topk_indices, cache_seqlens=infer_state.b_att_seq_len, cu_seqlens_q=infer_state.cu_seqlens_q, cu_seqlens_k_new=infer_state.cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, + window_size=(-1, -1), softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, return_softmax_lse=False, - num_splits=0, # TODO enable_deterministic_inference ) - return o \ No newline at end of file + return o_tensor \ No newline at end of file From 2cf08331c85acb39c14cf01b6422691508672dc4 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 7 Nov 2025 10:09:56 +0000 
Subject: [PATCH 05/58] need fix --- lightllm/models/deepseek3_2/__init__.py | 0 lightllm/models/deepseek3_2/infer_struct.py | 10 ++++++++-- .../layer_infer/transformer_layer_infer.py | 12 +++--------- lightllm/models/deepseek3_2/model.py | 5 +++++ 4 files changed, 16 insertions(+), 11 deletions(-) create mode 100644 lightllm/models/deepseek3_2/__init__.py diff --git a/lightllm/models/deepseek3_2/__init__.py b/lightllm/models/deepseek3_2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 4d77b5f6f3..b1e61413c4 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -9,8 +9,8 @@ def __init__(self): self.page_table_size_1 = None self.ks = None self.ke = None - - self.topk_indices = None + self.nsa_cu_seqlens_k = None + self.index_topk = 2048 return def init_some_extra_state(self, model, input_ids: torch.Tensor): @@ -24,3 +24,9 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): # since b_q_seq_len represents the new tokens being processed if self.b_ready_cache_len is None: self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len + + self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=model.index_topk) + assert self.nsa_cache_seqlens.dtype == torch.int32 + self.nsa_cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0) + ) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 01514e96a2..188ab8b4aa 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -86,7 +86,7 @@ def _nsa_context_attention_kernel( mla_out, _, _ = flash_mla_sparse_fwd( q=q_all, 
kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=self.topk_indices, + indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, ) @@ -100,23 +100,17 @@ def _nsa_token_attention_kernel( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - k_descale, v_descale = None, None o_tensor = flash_attn_with_kvcache( q=q_rope, k_cache=k_rope, v_cache=kv_nope, qv=q_nope, page_table=self.topk_indices, - cache_seqlens=infer_state.b_att_seq_len, + cache_seqlens=infer_state.nsa_cache_seqlens, cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, ) return o_tensor \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index c4e56c3c15..ad7f70550f 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -16,6 +16,11 @@ class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # infer state class infer_state_class = Deepseek3_2FlashAttentionStateInfo + def __init__(self, kvargs): + super().__init__(kvargs) + self.index_topk = self.config["index_topk"] + return + def _init_mem_manager(self): manager_class = Deepseek3_2MemoryManager if "triton_fp8kv" in self.mode: From bb8f087fba7050d9b2136503df74ae12eabf074d Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:07:54 +0000 Subject: [PATCH 06/58] run like deepseek v3 --- lightllm/models/deepseek3_2/infer_struct.py | 43 +++++- .../layer_infer/nsa_indexer_layer_inder.py | 104 ++++--------- .../layer_infer/transformer_layer_infer.py | 22 +-- 
lightllm/models/deepseek3_2/model.py | 5 +- .../triton_kernel/fp8_mqa_logits.py | 139 ++++++++++++++++++ 5 files changed, 224 insertions(+), 89 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index b1e61413c4..8e5eb0b819 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,5 +1,6 @@ import torch from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): @@ -15,6 +16,9 @@ def __init__(self): def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) + assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) + self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager + # Ensure b_ready_cache_len is set for both prefill and decode modes if self.is_prefill: # b_ready_cache_len is already set in basemodel.py for prefill @@ -24,9 +28,42 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): # since b_q_seq_len represents the new tokens being processed if self.b_ready_cache_len is None: self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len - - self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=model.index_topk) + + self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=self.index_topk) assert self.nsa_cache_seqlens.dtype == torch.int32 self.nsa_cu_seqlens_k = torch.nn.functional.pad( torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0) - ) \ No newline at end of file + ) + + # Pre-compute NSA indexer indexing structures + self._init_nsa_indexing_structures() + + def _init_nsa_indexing_structures(self): + """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer""" + mem_index_list = [] + ks_list = [] + ke_list = [] + lengths_list = [] + 
offset = 0 + num_seq_len = self.b_req_idx.shape[0] + self.page_table_size_1 = torch.zeros((num_seq_len, self.b_seq_len.max()), dtype=torch.int, device='cuda') + + for i in range(num_seq_len): + seq_len = self.b_seq_len[i] + q_seq_len = self.b_q_seq_len[i] + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + mem_index_list.append(mem_index) + self.page_table_size_1[i, :seq_len] = mem_index + ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset + ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 + ks_list.append(ks) + ke_list.append(ke) + lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) + offset += seq_len + + self.mem_index = torch.cat(mem_index_list, dim=0) + # ks : [seq_len_q] 标志kv的起始位置 + # ke : [seq_len_q] 标志kv的结束位置 + self.ks = torch.cat(ks_list, dim=0) + self.ke = torch.cat(ke_list, dim=0) + self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 3e5e1c266c..d7444e9187 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -10,7 +10,9 @@ from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks -# from lightllm.models.deepseek3_2.triton_kernel.fp8_mqa_logits import fp8_mqa_logits +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class NSAIndexerInfer(BaseLayerInfer): def __init__(self, layer_idx, network_config, mode=[]): @@ -66,70 +68,37 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q_fp8, q_scale = act_quant(q, self.block_size, 
self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) - self._copy_ks_to_mem_cache(k_fp8, k_scale, infer_state.mem_index, infer_state.mem_manager) + destindex_copy_indexer_ks( + k_fp8.unsqueeze(1), + k_scale.unsqueeze(1), + infer_state.mem_index, + infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + ) weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale - ks_buffer = infer_state.mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] - - k_fp8_list = [] - k_scale_list = [] - ks_list = [] - ke_list = [] - offset = 0 - for i in range(infer_state.batch_size): - q_len = infer_state.b_q_seq_len[i] - cache_len = infer_state.b_ready_cache_len[i] - mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] - k_fp8 = ks_buffer[mem_indexes, 0, :128].view(torch.float8_e4m3fn).contiguous() - k_scale = ks_buffer[mem_indexes, 0, 128:].view(torch.float32).contiguous() - ks = torch.full((q_len,), offset, dtype=torch.int32, device="cuda") - ke = ks + torch.arange(q_len, dtype=torch.int32, device="cuda") + 1 - k_fp8_list.append(k_fp8) - k_scale_list.append(k_scale) - ks_list.append(ks) - ke_list.append(ke) - offset += q_len - - k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) - k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) - kv_fp8 = (k_fp8, k_scale) - ks = torch.cat(ks_list, dim=0) - ke = torch.cat(ke_list, dim=0) - - logits = deep_gemm.fp8_mqa_logits( - q_fp8, - kv_fp8, - weights.squeeze(-1), - ks, - ke, - clean_logits=False, - ) - - return self.get_topk(logits, infer_state) - - def get_topk(self, logits, infer_state: Deepseek3_2FlashAttentionStateInfo): - topk_indices_list = [] - offset = 0 - - for i in range(infer_state.batch_size): - q_len = infer_state.b_q_seq_len[i] - cache_len = infer_state.b_ready_cache_len[i] - end_pos = q_len + cache_len - # Slice logits for this batch (both 
query and sequence dimensions) - batch_logits = logits[offset:offset + q_len, :end_pos] - topk_indices = batch_logits.topk(min(self.index_topk, end_pos), dim=-1)[1] - mem_indexes = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :cache_len+q_len] - indices = torch.full((q_len, self.index_topk), -1, dtype=torch.int32, device="cuda") - for j in range(q_len): - indices[j, :topk_indices[j].shape[0]] = mem_indexes[topk_indices[j]] - topk_indices_list.append(indices) - offset += q_len + # Use pre-computed indexing structures from infer_state + mem_index = infer_state.mem_index + ks = infer_state.ks + ke = infer_state.ke + lengths = infer_state.lengths + page_table_1 = infer_state.page_table_size_1 - topk_indices_ = torch.cat(topk_indices_list, dim=0) + # TODO + k_fp8_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1).contiguous() + k_scale_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, 128:].view(torch.float32)[:, 0, 0].contiguous() - return topk_indices_ + logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + + # 返回 : [seq_q_len, topk] 无效的位置使用-1填充 + return fast_topk_transform_fused( + score=logits, # [seq_len_q, seq_len_kv] + lengths=lengths, # [seq_len_q] + page_table_size_1=page_table_1, # [seq_len_q, max(lengths)] 无效的使用0填充 + cu_seqlens_q=infer_state.cu_seqlens_q, # [seq_len_q + 1] + topk=self.index_topk, + ) def get_k_float32_from_buffer(self, buffer: torch.Tensor): @@ -152,8 +121,9 @@ def _rotate_activation(x: torch.Tensor) -> torch.Tensor: def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) - k = layer_weight.wk_proj_.mm(hidden_states) + + # TODO k = F.layer_norm( k.float(), (self.index_head_dim,), 
layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) @@ -168,17 +138,3 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q = self._rotate_activation(q) k = self._rotate_activation(k) return q, k - - def _copy_ks_to_mem_cache(self, k_fp8, k_scale, mem_index, mem_manager: Deepseek3_2MemoryManager): - # k_fp8 : [seq_len, 128] torch.fp8_e4m3 - # k_scale : [seq_len, 1] torch.float32 - # mem_index : [seq_len] torch.int32 - # buffer : [10000000, 1, 132] torch.uint8 - buffer = mem_manager.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] - destindex_copy_indexer_ks( - k_fp8.unsqueeze(1), # Add head dimension: [seq_len, 1, 128] - k_scale.unsqueeze(1), # Add head dimension: [seq_len, 1, 1] - mem_index, - buffer - ) - return \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 188ab8b4aa..ed351312f4 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -82,10 +82,9 @@ def _nsa_context_attention_kernel( q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) - mla_out, _, _ = flash_mla_sparse_fwd( - q=q_all, - kv=infer_state.mem_manager.kv_buffer[self.layer_num_], + q=q_all, # [seq_len_q, q_num_head, qk_dim] + kv=infer_state.mem_manager.kv_buffer[self.layer_num_], # [size, 1, qk_dim] indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, @@ -100,15 +99,16 @@ def _nsa_token_attention_kernel( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) + 
o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, - v_cache=kv_nope, - qv=q_nope, - page_table=self.topk_indices, - cache_seqlens=infer_state.nsa_cache_seqlens, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, + q=q_rope, # (q_seqlen, nheads, qk_headdim) + k_cache=k_rope, # (kv_size, 1, 1, qk_head_dim) + v_cache=kv_nope, # (kv_size, 1, 1, kv_lora_rank) + qv=q_nope, # (q_seqlen, nheads, kv_lora_rank) + page_table=self.topk_indices, # (q_seqlen, max_seq_len) + cache_seqlens=infer_state.nsa_cache_seqlens, # (q_seqlen) # 表示当前kv长度,用于读取page_table. + cu_seqlens_q=infer_state.cu_seqlens_q, # (batch_size+1) [0,1] + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, #(batch_size+1) [0,9] max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index ad7f70550f..b800944886 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -5,7 +5,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager -@ModelRegistry(["deepseek_v32"]) +# @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class transformer_weight_class = Deepseek3_2TransformerLayerWeight @@ -21,6 +21,9 @@ def __init__(self, kvargs): self.index_topk = self.config["index_topk"] return + def _init_inferstate_cls(self): + self.infer_state_class = Deepseek3_2FlashAttentionStateInfo + def _init_mem_manager(self): manager_class = Deepseek3_2MemoryManager if "triton_fp8kv" in self.mode: diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py index e69de29bb2..2fc92662af 100644 --- 
a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py +++ b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py @@ -0,0 +1,139 @@ +import triton +import triton.language as tl +import torch + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + Q_ptr, KV_ptr, KVScale_ptr, Weights_ptr, MemIndex_ptr, + CuSeqlenKs_ptr, CuSeqlenKe_ptr, Output_ptr, + seq_len, seq_len_kv, num_heads, head_dim, + stride_q_seq, stride_q_head, stride_q_dim, + stride_kv_pool, stride_kv_dim, + stride_w_seq, stride_w_head, + stride_o_seq, stride_o_kv, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # Compute the range of seq positions this block handles + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + # Offset arrays for this block + offs_m = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_n = start_n + tl.arange(0, BLOCK_SIZE_N) + + # Initialize accumulator for logits + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Create masks + mask_m = offs_m < seq_len + mask_n = offs_n < seq_len_kv + + # Load mem_indices for the KV positions + mem_indices = tl.load(MemIndex_ptr + offs_n, mask=mask_n, other=0) + + # Load scales for K + scales = tl.load(KVScale_ptr + mem_indices, mask=mask_n, other=1.0) + + # Loop over all heads + for h in range(num_heads): + # Load weights for this head + weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, + mask=mask_m, other=0.0) + + # Initialize score accumulator for this head + score = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Loop over head_dim in blocks + for d_block in range(tl.cdiv(head_dim, BLOCK_SIZE_D)): + d_start = d_block * BLOCK_SIZE_D + offs_d = d_start + tl.arange(0, BLOCK_SIZE_D) + mask_d = offs_d < head_dim + + # Load Q for this head and dimension block + # Q shape: (seq_len, num_heads, head_dim) + q_ptrs = Q_ptr + offs_m[:, None] * stride_q_seq + h * 
stride_q_head + offs_d[None, :] * stride_q_dim + mask_q = (offs_m[:, None] < seq_len) & mask_d[None, :] + q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32) + + # Load K for this dimension block + # KV shape: (pool_size, head_dim) as FP8 data + k_ptrs = KV_ptr + mem_indices[:, None] * stride_kv_pool + offs_d[None, :] * stride_kv_dim + mask_k = mask_n[:, None] & mask_d[None, :] + k = tl.load(k_ptrs, mask=mask_k, other=0.0).to(tl.float32) + + # Apply scale to K (scale is per-row of K) + k = k * scales[:, None] + + # Compute partial dot product: q @ k.T + # q: (BLOCK_SIZE_M, BLOCK_SIZE_D), k: (BLOCK_SIZE_N, BLOCK_SIZE_D) + # score: (BLOCK_SIZE_M, BLOCK_SIZE_N) + score += tl.dot(q, tl.trans(k)) + + # Apply ReLU to score + score = tl.maximum(score, 0.0) + + # Multiply by weights and accumulate to logits + logits += score * weights[:, None] + + # Apply mask based on cu_seqlen_ks and cu_seqlen_ke + mask_ks = tl.load(CuSeqlenKs_ptr + offs_m, mask=mask_m, other=0) + mask_ke = tl.load(CuSeqlenKe_ptr + offs_m, mask=mask_m, other=seq_len_kv) + + mask_lo = offs_n[None, :] >= mask_ks[:, None] + mask_hi = offs_n[None, :] < mask_ke[:, None] + mask_valid = mask_lo & mask_hi & mask_m[:, None] & mask_n[None, :] + + # Apply mask (-inf for masked positions) + logits = tl.where(mask_valid, logits, float('-inf')) + + # Store output + out_ptrs = Output_ptr + offs_m[:, None] * stride_o_seq + offs_n[None, :] * stride_o_kv + mask_out = (offs_m[:, None] < seq_len) & (offs_n[None, :] < seq_len_kv) + tl.store(out_ptrs, logits, mask=mask_out) + + +def fp8_paged_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + kv_scale: torch.Tensor, + weights: torch.Tensor, + mem_index: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + out: torch.Tensor = None +) -> torch.Tensor: + seq_len, num_heads, head_dim = q.shape + seq_len_kv = mem_index.shape[0] + + if out is None: + output = torch.empty((seq_len, seq_len_kv), device=q.device, dtype=torch.float32) + else: + output = 
out + + BLOCK_SIZE_M = 16 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_D = 128 + + grid = (triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len_kv, BLOCK_SIZE_N)) + + _fp8_paged_mqa_logits_kernel[grid]( + q, kv, kv_scale, weights, mem_index, + cu_seqlen_ks, cu_seqlen_ke, output, + seq_len, seq_len_kv, num_heads, head_dim, + q.stride(0), q.stride(1), q.stride(2), + kv.stride(0), kv.stride(1), + weights.stride(0), weights.stride(1), + output.stride(0), output.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + + return output \ No newline at end of file From 3f3a6564bb1917c0ccf4a3629d7e03fcbe1a3bc8 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:41:25 +0000 Subject: [PATCH 07/58] fix --- .../layer_infer/nsa_indexer_layer_inder.py | 12 +- .../triton_kernel/extract_indexer_ks.py | 156 ++++++++++++++++++ 2 files changed, 160 insertions(+), 8 deletions(-) create mode 100644 lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index d7444e9187..173196bf42 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -10,6 +10,8 @@ from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks +from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks +from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -78,16 +80,13 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, weights = 
layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale - # Use pre-computed indexing structures from infer_state mem_index = infer_state.mem_index ks = infer_state.ks ke = infer_state.ke lengths = infer_state.lengths page_table_1 = infer_state.page_table_size_1 - # TODO - k_fp8_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, :128].view(torch.float8_e4m3fn).squeeze(1).contiguous() - k_scale_ = infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_][mem_index, :, 128:].view(torch.float32)[:, 0, 0].contiguous() + k_fp8_, k_scale_ = extract_indexer_ks(infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], mem_index) logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) @@ -123,10 +122,7 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) k = layer_weight.wk_proj_.mm(hidden_states) - # TODO - k = F.layer_norm( - k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps - ).type_as(k) + k = layernorm_forward(k, layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps) rotary_emb_fwd( q[:, :, : self.qk_rope_head_dim], diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py new file mode 100644 index 0000000000..e97454ba2e --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -0,0 +1,156 @@ +import torch +import triton +import triton.language as tl +import numpy + + +@triton.jit +def _fwd_kernel_extract_indexer_ks( + buffer_fp8, + buffer_scale, + mem_index, + k_fp8_out, + k_scale_out, + stride_buffer_fp8_bs, + stride_buffer_fp8_h, + stride_buffer_fp8_d, + stride_buffer_scale_bs, + stride_buffer_scale_h, + stride_buffer_scale_d, + stride_k_fp8_out_bs, + 
stride_k_fp8_out_d, + stride_k_scale_out_bs, + BLOCK_DMODEL: tl.constexpr, +): + cur_index = tl.program_id(0) + + # Load the memory index + mem_idx = tl.load(mem_index + cur_index).to(tl.int64) + + # Load k_fp8 data from buffer_fp8[mem_idx, 0, :] + offs_d = tl.arange(0, BLOCK_DMODEL) + k_fp8_ptrs = buffer_fp8 + mem_idx * stride_buffer_fp8_bs + 0 * stride_buffer_fp8_h + offs_d * stride_buffer_fp8_d + k_fp8_data = tl.load(k_fp8_ptrs) + + # Load k_scale data from buffer_scale[mem_idx, 0, 0] + k_scale_ptr = buffer_scale + mem_idx * stride_buffer_scale_bs + 0 * stride_buffer_scale_h + 0 * stride_buffer_scale_d + k_scale_data = tl.load(k_scale_ptr) + + # Store k_fp8 output + k_fp8_out_ptrs = k_fp8_out + cur_index * stride_k_fp8_out_bs + offs_d * stride_k_fp8_out_d + tl.store(k_fp8_out_ptrs, k_fp8_data) + + # Store k_scale output + k_scale_out_ptr = k_scale_out + cur_index * stride_k_scale_out_bs + tl.store(k_scale_out_ptr, k_scale_data) + + +@torch.no_grad() +def extract_indexer_ks(buffer, mem_index): + """ + Extract k_fp8 and k_scale from the indexer memory buffer using Triton kernel. 
+ + Args: + buffer: Memory buffer of shape [total_tokens, heads, 132] with dtype uint8 + mem_index: Indices tensor of shape [seq_len] with dtype int32/int64 + + Returns: + k_fp8: Tensor of shape [seq_len, 128] with dtype float8_e4m3fn + k_scale: Tensor of shape [seq_len] with dtype float32 + """ + seq_len = mem_index.shape[0] + assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" + + # Reinterpret buffer as the appropriate types for Triton + buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) + buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] + + # Prepare output tensors + k_fp8_out = torch.empty((seq_len, 128), dtype=torch.float8_e4m3fn, device=buffer.device) + k_scale_out = torch.empty((seq_len,), dtype=torch.float32, device=buffer.device) + + BLOCK_DMODEL = 128 + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_extract_indexer_ks[grid]( + buffer_fp8, + buffer_scale, + mem_index, + k_fp8_out, + k_scale_out, + buffer_fp8.stride(0), + buffer_fp8.stride(1), + buffer_fp8.stride(2), + buffer_scale.stride(0), + buffer_scale.stride(1), + buffer_scale.stride(2), + k_fp8_out.stride(0), + k_fp8_out.stride(1), + k_scale_out.stride(0), + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + ) + + return k_fp8_out, k_scale_out + + +def test(): + # Test parameters similar to the usage in nsa_indexer_layer_inder.py + B, N_CTX, H = 4, 1024, 1 # batch_size, seq_len, heads (always 1 for this) + seq_len = 50 # number of tokens to extract + dtype_fp8 = torch.float8_e4m3fn + dtype_scale = torch.float32 + + # Create test buffer [total_tokens, heads, 132] as uint8 + buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() + + # Fill with test data - simulate what destindex_copy_indexer_ks does + test_indices = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() + # Generate fp8 data by converting from float32 + test_k_fp8_fp32 = torch.randn((seq_len, 128), dtype=torch.float32).cuda() + 
test_k_fp8 = test_k_fp8_fp32.to(dtype_fp8) + test_k_scale = torch.randn((seq_len,), dtype=dtype_scale).cuda() + + # Manually populate buffer as destindex_copy_indexer_ks would + for i in range(seq_len): + dest_idx = test_indices[i].item() + # Store fp8 data + buffer[dest_idx, 0, :128] = test_k_fp8[i].view(torch.uint8) + # Store scale data (4 bytes) - need to convert float32 to bytes + scale_bytes = test_k_scale[i].cpu().numpy().tobytes() + scale_bytes_np = numpy.frombuffer(scale_bytes, dtype=numpy.uint8) + buffer[dest_idx, 0, 128:132] = torch.from_numpy(scale_bytes_np).to(buffer.device) + + # Call our extraction function + extracted_fp8, extracted_scale = extract_indexer_ks(buffer, test_indices) + + # Verify results + print(f"Original k_fp8 shape: {test_k_fp8.shape}, dtype: {test_k_fp8.dtype}") + print(f"Extracted k_fp8 shape: {extracted_fp8.shape}, dtype: {extracted_fp8.dtype}") + print(f"Original k_scale shape: {test_k_scale.shape}, dtype: {test_k_scale.dtype}") + print(f"Extracted k_scale shape: {extracted_scale.shape}, dtype: {extracted_scale.dtype}") + + # Check if extraction matches (convert fp8 to float32 for comparison) + # Use higher tolerance for fp8 due to quantization precision + fp8_match = torch.allclose(test_k_fp8_fp32, extracted_fp8.float(), atol=0.1, rtol=0.1) + scale_match = torch.allclose(test_k_scale, extracted_scale, atol=1e-6) + + print(f"FP8 data matches: {fp8_match}") + print(f"Scale data matches: {scale_match}") + + if fp8_match and scale_match: + print("All tests passed!") + else: + print("Test failed!") + if not fp8_match: + print("First few fp8 values:") + print(f"Original: {test_k_fp8_fp32[0, :5]}") + print(f"Extracted: {extracted_fp8.float()[0, :5]}") + if not scale_match: + print(f"Max scale diff: {torch.max(torch.abs(test_k_scale - extracted_scale))}") + + +if __name__ == "__main__": + test() From 19c91289bf32b23afe3af23b35d5ad6fea1692ab Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:49:38 +0000 Subject: [PATCH 
08/58] fix --- lightllm/models/deepseek3_2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index b800944886..8f1ba85cf2 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -5,7 +5,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager -# @ModelRegistry(["deepseek_v32"]) +@ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class transformer_weight_class = Deepseek3_2TransformerLayerWeight From 303ca1955262ce30874a5cdd52fda58ef2a1578b Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:51:11 +0000 Subject: [PATCH 09/58] fix --- .../models/deepseek3_2/layer_infer/transformer_layer_infer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index ed351312f4..5fc33d5aa6 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,6 +1,5 @@ from functools import partial from typing import override -from venv import logger import torch from sgl_kernel.flash_mla import flash_mla_sparse_fwd From 3b3748af13142ef3c8f07611f24dc13734cf6c73 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 04:53:59 +0000 Subject: [PATCH 10/58] fix --- lightllm/models/deepseek3_2/infer_struct.py | 2 -- .../layer_infer/nsa_indexer_layer_inder.py | 9 ++++----- .../layer_infer/transformer_layer_infer.py | 20 +++++++++---------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py 
b/lightllm/models/deepseek3_2/infer_struct.py index 8e5eb0b819..e955c3bbde 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -62,8 +62,6 @@ def _init_nsa_indexing_structures(self): offset += seq_len self.mem_index = torch.cat(mem_index_list, dim=0) - # ks : [seq_len_q] 标志kv的起始位置 - # ke : [seq_len_q] 标志kv的结束位置 self.ks = torch.cat(ks_list, dim=0) self.ke = torch.cat(ke_list, dim=0) self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 173196bf42..d5032e72f1 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -90,12 +90,11 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) - # 返回 : [seq_q_len, topk] 无效的位置使用-1填充 return fast_topk_transform_fused( - score=logits, # [seq_len_q, seq_len_kv] - lengths=lengths, # [seq_len_q] - page_table_size_1=page_table_1, # [seq_len_q, max(lengths)] 无效的使用0填充 - cu_seqlens_q=infer_state.cu_seqlens_q, # [seq_len_q + 1] + score=logits, + lengths=lengths, + page_table_size_1=page_table_1, + cu_seqlens_q=infer_state.cu_seqlens_q, topk=self.index_topk, ) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 5fc33d5aa6..5b550ab09a 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -82,8 +82,8 @@ def _nsa_context_attention_kernel( q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) mla_out, _, _ = flash_mla_sparse_fwd( - q=q_all, # 
[seq_len_q, q_num_head, qk_dim] - kv=infer_state.mem_manager.kv_buffer[self.layer_num_], # [size, 1, qk_dim] + q=q_all, + kv=infer_state.mem_manager.kv_buffer[self.layer_num_], indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, d_v=self.kv_lora_rank, @@ -100,14 +100,14 @@ def _nsa_token_attention_kernel( kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) o_tensor = flash_attn_with_kvcache( - q=q_rope, # (q_seqlen, nheads, qk_headdim) - k_cache=k_rope, # (kv_size, 1, 1, qk_head_dim) - v_cache=kv_nope, # (kv_size, 1, 1, kv_lora_rank) - qv=q_nope, # (q_seqlen, nheads, kv_lora_rank) - page_table=self.topk_indices, # (q_seqlen, max_seq_len) - cache_seqlens=infer_state.nsa_cache_seqlens, # (q_seqlen) # 表示当前kv长度,用于读取page_table. - cu_seqlens_q=infer_state.cu_seqlens_q, # (batch_size+1) [0,1] - cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, #(batch_size+1) [0,9] + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=self.topk_indices, + cache_seqlens=infer_state.nsa_cache_seqlens, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, From 5d8119ef1664235eea714a6c62879f89c6bb8f85 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 13:57:17 +0000 Subject: [PATCH 11/58] can run without cudagraph --- lightllm/models/deepseek3_2/infer_struct.py | 6 +- .../layer_infer/nsa_indexer_layer_inder.py | 24 +- .../layer_infer/transformer_layer_infer.py | 6 +- .../destindex_copy_indexer_ks.py | 354 ++++++++++----- .../triton_kernel/extract_indexer_ks.py | 409 ++++++++++++------ 5 files changed, 547 insertions(+), 252 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index e955c3bbde..c122c6a7e7 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -40,7 +40,7 @@ def 
init_some_extra_state(self, model, input_ids: torch.Tensor): def _init_nsa_indexing_structures(self): """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer""" - mem_index_list = [] + req_all_mem_index_list = [] ks_list = [] ke_list = [] lengths_list = [] @@ -52,7 +52,7 @@ def _init_nsa_indexing_structures(self): seq_len = self.b_seq_len[i] q_seq_len = self.b_q_seq_len[i] mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] - mem_index_list.append(mem_index) + req_all_mem_index_list.append(mem_index) self.page_table_size_1[i, :seq_len] = mem_index ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 @@ -61,7 +61,7 @@ def _init_nsa_indexing_structures(self): lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) offset += seq_len - self.mem_index = torch.cat(mem_index_list, dim=0) + self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) self.ks = torch.cat(ks_list, dim=0) self.ke = torch.cat(ke_list, dim=0) self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index d5032e72f1..df045dd2d2 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -71,8 +71,8 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) destindex_copy_indexer_ks( - k_fp8.unsqueeze(1), - k_scale.unsqueeze(1), + k_fp8, + k_scale, infer_state.mem_index, infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] ) @@ -80,13 +80,16 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, weights = layer_weight.weights_proj_.mm(hidden_states) * 
self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale - mem_index = infer_state.mem_index ks = infer_state.ks ke = infer_state.ke lengths = infer_state.lengths page_table_1 = infer_state.page_table_size_1 - k_fp8_, k_scale_ = extract_indexer_ks(infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], mem_index) + # Use efficient Triton kernel to extract FP8 keys and scales from buffer + k_fp8_, k_scale_ = extract_indexer_ks( + infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], + infer_state.req_all_mem_index + ) logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) @@ -99,12 +102,6 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, ) - def get_k_float32_from_buffer(self, buffer: torch.Tensor): - k_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) - k_scale = buffer[:, :, 128:].view(torch.float32)[:, :, :1] - k_float32 = k_fp8.float() * k_scale - return k_float32 - @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 @@ -121,8 +118,11 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) k = layer_weight.wk_proj_.mm(hidden_states) - k = layernorm_forward(k, layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps) - + # TODO + k = F.layer_norm( + k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps + ).type_as(k) + rotary_emb_fwd( q[:, :, : self.qk_rope_head_dim], k[:, None, : self.qk_rope_head_dim], diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 5b550ab09a..df52204270 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -64,7 +64,11 @@ def _get_qkv( 
@override def _bind_attention(self): - super()._bind_attention() + if "triton_fp8kv" in self.mode: + self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self) + else: + self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_context_attention_kernel, self) self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_token_attention_kernel, self) pass diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py index a098795fb7..46095bfb75 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -6,132 +6,270 @@ @triton.jit def _fwd_kernel_destindex_copy_indexer_ks( - k_fp8, - k_scale, - mem_index, - buffer_fp8, - buffer_scale, - stride_k_fp8_bs, - stride_k_fp8_h, - stride_k_fp8_d, - stride_k_scale_bs, - stride_k_scale_h, - stride_k_scale_d, - stride_buffer_fp8_bs, - stride_buffer_fp8_h, - stride_buffer_fp8_d, - stride_buffer_scale_bs, - stride_buffer_scale_h, - stride_buffer_scale_d, - head_num, + K_fp8, + K_scale, + DestLoc, + O_buffer, + stride_k_bs, + stride_k_d, + stride_scale_bs, + stride_scale_d, + stride_o_bs, + stride_o_h, + stride_o_d, BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, ): + """ + Triton kernel to copy FP8 K values and their scales to an indexed output buffer. + + This kernel reads FP8 key values (128 dims) and their float32 scale values, + then writes them to a compact buffer format where each entry contains: + - Bytes 0-127: FP8 key values (128 bytes) + - Bytes 128-131: Float32 scale (4 bytes) + + The destination location for each source element is specified by DestLoc. 
+ """ cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(mem_index + cur_index).to(tl.int64) - - # Load k_fp8 data - k_fp8_ptrs = k_fp8 + cur_index * stride_k_fp8_bs + stride_k_fp8_h * offs_h[:, None] + stride_k_fp8_d * offs_d[None, :] - k_fp8_data = tl.load(k_fp8_ptrs, mask=offs_h[:, None] < head_num, other=0.0) - - # Load k_scale data - k_scale_ptrs = k_scale + cur_index * stride_k_scale_bs + stride_k_scale_h * offs_h[:, None] + stride_k_scale_d * tl.arange(0, 1)[None, :] - k_scale_data = tl.load(k_scale_ptrs, mask=offs_h[:, None] < head_num, other=0.0) - - # Store k_fp8 to buffer_fp8 - buffer_fp8_ptrs = buffer_fp8 + dest_index * stride_buffer_fp8_bs + stride_buffer_fp8_h * offs_h[:, None] + stride_buffer_fp8_d * offs_d[None, :] - tl.store(buffer_fp8_ptrs, k_fp8_data, mask=offs_h[:, None] < head_num) - - # Store k_scale to buffer_scale - buffer_scale_ptrs = buffer_scale + dest_index * stride_buffer_scale_bs + stride_buffer_scale_h * offs_h[:, None] + stride_buffer_scale_d * tl.arange(0, 1)[None, :] - tl.store(buffer_scale_ptrs, k_scale_data, mask=offs_h[:, None] < head_num) + + # Load destination index for this thread + dest_index = tl.load(DestLoc + cur_index).to(tl.int64) + + # Load K_fp8 (128 values) and K_scale (1 value) from source + k_fp8_ptrs = K_fp8 + cur_index * stride_k_bs + stride_k_d * offs_d + k_fp8 = tl.load(k_fp8_ptrs) + + k_scale = tl.load(K_scale + cur_index * stride_scale_bs) + + # Store K_fp8 to O_buffer[:, 0, :128] + # Convert fp8 to uint8 through bitcast for storage in uint8 buffer + o_k_ptrs = O_buffer + dest_index * stride_o_bs + stride_o_d * offs_d + k_fp8_as_uint8 = k_fp8.to(tl.uint8, bitcast=True) + tl.store(o_k_ptrs, k_fp8_as_uint8) + + # Store K_scale to O_buffer[:, 0, 128:132] (4 bytes for float32) + # Convert float32 scale to 4 uint8 bytes using bitcast and bit manipulation + o_scale_ptr = O_buffer + dest_index * stride_o_bs + BLOCK_DMODEL * stride_o_d + 
scale_as_uint32 = k_scale.to(tl.float32, bitcast=True).to(tl.uint32, bitcast=True) + + # Store each byte of the float32 scale (little-endian) + for i in range(4): + byte_val = ((scale_as_uint32 >> (i * 8)) & 0xFF).to(tl.uint8) + tl.store(o_scale_ptr + i * stride_o_d, byte_val) + + return @torch.no_grad() -def destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer): - seq_len = mem_index.shape[0] - head_num = k_fp8.shape[1] - k_fp8_dim = k_fp8.shape[2] # Should be 128 for float8 - k_scale_dim = k_scale.shape[2] # Should be 1 +def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor): + """ + Copy FP8-quantized key values and their scales to indexed locations in a buffer. + + This function is used in the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) + mechanism to store compressed key representations in a memory buffer. Each key + is stored with its FP8 representation (128 bytes) followed by its float32 scale + (4 bytes), for a total of 132 bytes per key. + + Args: + K_fp8: [q_seq_len, 128] torch.fp8_e4m3fn + FP8-quantized key values + K_scale: [q_seq_len, 1] torch.float32 + Quantization scales for each key + DestLoc: [q_seq_len] torch.int32 + Destination indices in the output buffer + O_buffer: [large_size, 1, 132] torch.uint8 + Output buffer where keys and scales will be written. + Must be a uint8 tensor to allow mixed-type storage. 
+ Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales - assert k_fp8.shape[1] == k_scale.shape[1] - assert k_fp8_dim == 128, f"k_fp8 dim should be 128, got {k_fp8_dim}" - assert k_scale_dim == 1, f"k_scale dim should be 1, got {k_scale_dim}" - assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" # 128 + 4 bytes - - # Reinterpret buffer as the appropriate types for storing - buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) - buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] - - BLOCK_HEAD = triton.next_power_of_2(head_num) + Returns: + None (modifies O_buffer in-place) + + Example: + >>> k_fp8 = torch.randn(50, 128).to(torch.float8_e4m3fn).cuda() + >>> k_scale = torch.randn(50, 1).cuda() + >>> dest_loc = torch.randint(0, 1024, (50,), dtype=torch.int32).cuda() + >>> o_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() + >>> destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer) + >>> # Now o_buffer[dest_loc] contains the packed k_fp8 and k_scale data + """ + seq_len = DestLoc.shape[0] + head_dim = K_fp8.shape[1] + + assert head_dim == 128, f"Expected head_dim=128, got {head_dim}" + assert K_scale.shape[0] == seq_len + assert O_buffer.shape[2] == 132, f"Expected O_buffer last dim=132, got {O_buffer.shape[2]}" + grid = (seq_len,) num_warps = 1 - + _fwd_kernel_destindex_copy_indexer_ks[grid]( - k_fp8, - k_scale, - mem_index, - buffer_fp8, - buffer_scale, - k_fp8.stride(0), - k_fp8.stride(1), - k_fp8.stride(2), - k_scale.stride(0), - k_scale.stride(1), - k_scale.stride(2), - buffer_fp8.stride(0), - buffer_fp8.stride(1), - buffer_fp8.stride(2), - buffer_scale.stride(0), - buffer_scale.stride(1), - buffer_scale.stride(2), - head_num, - BLOCK_DMODEL=k_fp8_dim, - BLOCK_HEAD=BLOCK_HEAD, + K_fp8, + K_scale, + DestLoc, + O_buffer, + K_fp8.stride(0), + K_fp8.stride(1), + K_scale.stride(0), + K_scale.stride(1), + O_buffer.stride(0), + O_buffer.stride(1), + O_buffer.stride(2), + 
BLOCK_DMODEL=head_dim, num_warps=num_warps, num_stages=1, ) return -def test(): +def test_destindex_copy_indexer_ks(): + """Test the destindex_copy_indexer_ks kernel""" import torch.nn.functional as F - - # Test parameters similar to the usage in nsa_indexer_layer_inder.py - B, N_CTX, H, K_DIM = 4, 1024, 8, 128 # batch_size, seq_len, heads, k_dim - seq_len = 50 # number of tokens to copy - dtype_fp8 = torch.float8_e4m3fn - dtype_scale = torch.float32 - - # Create test data - k_fp8 = torch.randn((seq_len, H, K_DIM), dtype=dtype_fp8).cuda() - k_scale = torch.randn((seq_len, H, 1), dtype=dtype_scale).cuda() - mem_index = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() - - # Create buffer [total_tokens, heads, 132] - buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() - - # Call the function - destindex_copy_indexer_ks(k_fp8, k_scale, mem_index, buffer) - - # Verify results - for i in range(seq_len): - dest_idx = mem_index[i].item() - # Check k_fp8 part - stored_fp8 = buffer[dest_idx, :, :128].view(dtype_fp8) - expected_fp8 = k_fp8[i] - assert torch.allclose(stored_fp8, expected_fp8, atol=1e-6), f"FP8 mismatch at index {i}" - - # Check k_scale part - stored_scale = buffer[dest_idx, :, 128:].view(dtype_scale)[:, :1] - expected_scale = k_scale[i] - assert torch.allclose(stored_scale, expected_scale, atol=1e-6), f"Scale mismatch at index {i}" - - print("All tests passed!") + + print("=" * 80) + print("Testing destindex_copy_indexer_ks") + print("=" * 80) + + # Test parameters + q_seq_len = 50 + head_dim = 128 + large_size = 1024 + dtype = torch.bfloat16 + fp8_type = torch.float8_e4m3fn + + # Create random destination indices + dest_loc = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() + actual_seq_len = len(dest_loc) + + # Create input tensors + k_bf16 = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") + + # Quantize to FP8 + k_abs_max = k_bf16.abs().max(dim=1, 
keepdim=True)[0].clamp(min=1e-12) + k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8 = (k_bf16 / k_abs_max).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + # Create output buffer (as uint8 to allow reinterpretation) + o_buffer_uint8 = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + + # Run kernel + destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer_uint8) + + # Extract results + k_fp8_out = o_buffer_uint8[:, 0, :128].view(fp8_type) + + # Extract scale by reinterpreting 4 bytes as float32 + scale_bytes = o_buffer_uint8[:, 0, 128:132].contiguous() + k_scale_out = scale_bytes.view(-1, 4).view(torch.float32).squeeze(-1) + + # Verify results at destination locations + k_fp8_extracted = k_fp8_out[dest_loc] + k_scale_extracted = k_scale_out[dest_loc] + + # Check FP8 values match + fp8_match = torch.allclose( + k_fp8_extracted.to(torch.float32), + k_fp8.to(torch.float32), + atol=0, rtol=0 + ) + + # Check scales match + scale_match = torch.allclose( + k_scale_extracted, + k_scale.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + # Check dequantized values + k_dequant_out = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) + cosine_sim = F.cosine_similarity(k_dequant_out, k_bf16, dim=-1).mean() + + print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") + print(f" FP8 values match: {fp8_match}") + print(f" Scale values match: {scale_match}") + print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") + + assert fp8_match, "FP8 values do not match!" + assert scale_match, "Scale values do not match!" 
+ assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" + + print("✓ Basic test passed!") + print() + + # Test edge cases + print("Testing edge cases...") + + # Test with sequential indices + dest_loc_seq = torch.arange(20, device="cuda", dtype=torch.int32) + k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") + k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + o_buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, dest_loc_seq, o_buffer_seq) + + k_fp8_out_seq = o_buffer_seq[:20, 0, :128].view(fp8_type) + scale_bytes_seq = o_buffer_seq[:20, 0, 128:132].contiguous() + k_scale_out_seq = scale_bytes_seq.view(-1, 4).view(torch.float32).squeeze(-1) + + fp8_match_seq = torch.allclose( + k_fp8_out_seq.to(torch.float32), + k_fp8_seq.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_seq = torch.allclose( + k_scale_out_seq, + k_scale_seq.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Sequential indices test: FP8={fp8_match_seq}, Scale={scale_match_seq}") + assert fp8_match_seq and scale_match_seq + print("✓ Edge case tests passed!") + print() + + # Test with single element + print("Testing single element...") + dest_loc_single = torch.tensor([42], device="cuda", dtype=torch.int32) + k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") + k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + o_buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, 
device="cuda") + destindex_copy_indexer_ks(k_fp8_single, k_scale_single, dest_loc_single, o_buffer_single) + + k_fp8_out_single = o_buffer_single[42:43, 0, :128].view(fp8_type) + scale_bytes_single = o_buffer_single[42:43, 0, 128:132].contiguous() + k_scale_out_single = scale_bytes_single.view(-1, 4).view(torch.float32).squeeze(-1) + + fp8_match_single = torch.allclose( + k_fp8_out_single.to(torch.float32), + k_fp8_single.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_single = torch.allclose( + k_scale_out_single, + k_scale_single.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Single element test: FP8={fp8_match_single}, Scale={scale_match_single}") + assert fp8_match_single and scale_match_single + print("✓ Single element test passed!") + print() + + print("=" * 80) + print("All tests passed successfully! ✓") + print("=" * 80) if __name__ == "__main__": - test() + test_destindex_copy_indexer_ks() \ No newline at end of file diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index e97454ba2e..eb22fbb8f7 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -1,156 +1,309 @@ import torch + import triton import triton.language as tl -import numpy @triton.jit def _fwd_kernel_extract_indexer_ks( - buffer_fp8, - buffer_scale, - mem_index, - k_fp8_out, - k_scale_out, - stride_buffer_fp8_bs, - stride_buffer_fp8_h, - stride_buffer_fp8_d, - stride_buffer_scale_bs, - stride_buffer_scale_h, - stride_buffer_scale_d, - stride_k_fp8_out_bs, - stride_k_fp8_out_d, - stride_k_scale_out_bs, + I_buffer, # Input buffer [large_size, 1, 132] uint8 + SrcLoc, # Source indices [req_size] int32/int64 + O_fp8, # Output FP8 [req_size, 128] float8_e4m3fn + O_scale, # Output scale [req_size] float32 + stride_i_bs, + stride_i_h, + stride_i_d, + stride_o_fp8_bs, + stride_o_fp8_d, + stride_o_scale_bs, 
BLOCK_DMODEL: tl.constexpr, ): + """ + Triton kernel to extract FP8 K values and their scales from an indexed buffer. + + This kernel is the inverse of destindex_copy_indexer_ks. It reads from a + compact buffer format where each entry contains: + - Bytes 0-127: FP8 key values (128 bytes) + - Bytes 128-131: Float32 scale (4 bytes) + + The source location for each output element is specified by SrcLoc. + """ cur_index = tl.program_id(0) - - # Load the memory index - mem_idx = tl.load(mem_index + cur_index).to(tl.int64) - - # Load k_fp8 data from buffer_fp8[mem_idx, 0, :] offs_d = tl.arange(0, BLOCK_DMODEL) - k_fp8_ptrs = buffer_fp8 + mem_idx * stride_buffer_fp8_bs + 0 * stride_buffer_fp8_h + offs_d * stride_buffer_fp8_d - k_fp8_data = tl.load(k_fp8_ptrs) - - # Load k_scale data from buffer_scale[mem_idx, 0, 0] - k_scale_ptr = buffer_scale + mem_idx * stride_buffer_scale_bs + 0 * stride_buffer_scale_h + 0 * stride_buffer_scale_d - k_scale_data = tl.load(k_scale_ptr) - - # Store k_fp8 output - k_fp8_out_ptrs = k_fp8_out + cur_index * stride_k_fp8_out_bs + offs_d * stride_k_fp8_out_d - tl.store(k_fp8_out_ptrs, k_fp8_data) - - # Store k_scale output - k_scale_out_ptr = k_scale_out + cur_index * stride_k_scale_out_bs - tl.store(k_scale_out_ptr, k_scale_data) + + # Load source index for this thread + src_index = tl.load(SrcLoc + cur_index).to(tl.int64) + + # Load K_fp8 from I_buffer[:, 0, :128] + i_k_ptrs = I_buffer + src_index * stride_i_bs + stride_i_d * offs_d + k_fp8_as_uint8 = tl.load(i_k_ptrs) + + # Convert uint8 to fp8 through bitcast + k_fp8 = k_fp8_as_uint8.to(tl.float8e4nv, bitcast=True) + + # Store K_fp8 to output + o_k_ptrs = O_fp8 + cur_index * stride_o_fp8_bs + stride_o_fp8_d * offs_d + tl.store(o_k_ptrs, k_fp8) + + # Load K_scale from I_buffer[:, 0, 128:132] (4 bytes for float32) + # Load 4 bytes and reconstruct float32 (little-endian) + i_scale_base_ptr = I_buffer + src_index * stride_i_bs + BLOCK_DMODEL * stride_i_d + + # Load 4 bytes individually and 
combine them into uint32 + byte0 = tl.load(i_scale_base_ptr + 0 * stride_i_d).to(tl.uint32) + byte1 = tl.load(i_scale_base_ptr + 1 * stride_i_d).to(tl.uint32) + byte2 = tl.load(i_scale_base_ptr + 2 * stride_i_d).to(tl.uint32) + byte3 = tl.load(i_scale_base_ptr + 3 * stride_i_d).to(tl.uint32) + + # Combine bytes into uint32 (little-endian: byte0 is LSB) + scale_as_uint32 = byte0 | (byte1 << 8) | (byte2 << 16) | (byte3 << 24) + + # Bitcast uint32 to float32 + k_scale = scale_as_uint32.to(tl.float32, bitcast=True) + + # Store scale to output + o_scale_ptr = O_scale + cur_index * stride_o_scale_bs + tl.store(o_scale_ptr, k_scale) + + return @torch.no_grad() -def extract_indexer_ks(buffer, mem_index): +def extract_indexer_ks(I_buffer: torch.Tensor, SrcLoc: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ - Extract k_fp8 and k_scale from the indexer memory buffer using Triton kernel. - + Extract FP8-quantized key values and their scales from indexed locations in a buffer. + + This function is the inverse operation of destindex_copy_indexer_ks. It's used in + the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) mechanism to retrieve + compressed key representations from a memory buffer. + Args: - buffer: Memory buffer of shape [total_tokens, heads, 132] with dtype uint8 - mem_index: Indices tensor of shape [seq_len] with dtype int32/int64 - + I_buffer: [large_size, 1, 132] torch.uint8 + Input buffer containing packed FP8 keys and float32 scales. 
+ Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales + SrcLoc: [req_size] torch.int32 or torch.int64 + Source indices to extract from the input buffer + Returns: - k_fp8: Tensor of shape [seq_len, 128] with dtype float8_e4m3fn - k_scale: Tensor of shape [seq_len] with dtype float32 + tuple containing: + - K_fp8: [req_size, 128] torch.float8_e4m3fn + FP8-quantized key values + - K_scale: [req_size] torch.float32 + Quantization scales for each key + + Example: + >>> i_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() + >>> src_loc = torch.tensor([10, 20, 30], dtype=torch.int32).cuda() + >>> k_fp8, k_scale = extract_indexer_ks(i_buffer, src_loc) + >>> # k_fp8.shape == [3, 128], k_scale.shape == [3] """ - seq_len = mem_index.shape[0] - assert buffer.shape[2] == 132, f"buffer dim should be 132, got {buffer.shape[2]}" - - # Reinterpret buffer as the appropriate types for Triton - buffer_fp8 = buffer[:, :, :128].view(torch.float8_e4m3fn) - buffer_scale = buffer[:, :, 128:132].view(torch.float32)[:, :, :1] - - # Prepare output tensors - k_fp8_out = torch.empty((seq_len, 128), dtype=torch.float8_e4m3fn, device=buffer.device) - k_scale_out = torch.empty((seq_len,), dtype=torch.float32, device=buffer.device) - - BLOCK_DMODEL = 128 - grid = (seq_len,) + req_size = SrcLoc.shape[0] + head_dim = 128 + + assert I_buffer.dtype == torch.uint8, f"Expected I_buffer dtype=uint8, got {I_buffer.dtype}" + assert I_buffer.shape[2] == 132, f"Expected I_buffer last dim=132, got {I_buffer.shape[2]}" + + # Allocate output tensors + O_fp8 = torch.empty((req_size, head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) + O_scale = torch.empty((req_size,), dtype=torch.float32, device=I_buffer.device) + + grid = (req_size,) num_warps = 1 - + _fwd_kernel_extract_indexer_ks[grid]( - buffer_fp8, - buffer_scale, - mem_index, - k_fp8_out, - k_scale_out, - buffer_fp8.stride(0), - buffer_fp8.stride(1), - buffer_fp8.stride(2), - buffer_scale.stride(0), - 
buffer_scale.stride(1), - buffer_scale.stride(2), - k_fp8_out.stride(0), - k_fp8_out.stride(1), - k_scale_out.stride(0), - BLOCK_DMODEL=BLOCK_DMODEL, + I_buffer, + SrcLoc, + O_fp8, + O_scale, + I_buffer.stride(0), + I_buffer.stride(1), + I_buffer.stride(2), + O_fp8.stride(0), + O_fp8.stride(1), + O_scale.stride(0), + BLOCK_DMODEL=head_dim, num_warps=num_warps, num_stages=1, ) + + return O_fp8, O_scale - return k_fp8_out, k_scale_out - - -def test(): - # Test parameters similar to the usage in nsa_indexer_layer_inder.py - B, N_CTX, H = 4, 1024, 1 # batch_size, seq_len, heads (always 1 for this) - seq_len = 50 # number of tokens to extract - dtype_fp8 = torch.float8_e4m3fn - dtype_scale = torch.float32 - # Create test buffer [total_tokens, heads, 132] as uint8 - buffer = torch.zeros((B * N_CTX, H, 132), dtype=torch.uint8).cuda() - - # Fill with test data - simulate what destindex_copy_indexer_ks does - test_indices = torch.randint(0, B * N_CTX, (seq_len,), dtype=torch.int32).cuda() - # Generate fp8 data by converting from float32 - test_k_fp8_fp32 = torch.randn((seq_len, 128), dtype=torch.float32).cuda() - test_k_fp8 = test_k_fp8_fp32.to(dtype_fp8) - test_k_scale = torch.randn((seq_len,), dtype=dtype_scale).cuda() - - # Manually populate buffer as destindex_copy_indexer_ks would - for i in range(seq_len): - dest_idx = test_indices[i].item() - # Store fp8 data - buffer[dest_idx, 0, :128] = test_k_fp8[i].view(torch.uint8) - # Store scale data (4 bytes) - need to convert float32 to bytes - scale_bytes = test_k_scale[i].cpu().numpy().tobytes() - scale_bytes_np = numpy.frombuffer(scale_bytes, dtype=numpy.uint8) - buffer[dest_idx, 0, 128:132] = torch.from_numpy(scale_bytes_np).to(buffer.device) - - # Call our extraction function - extracted_fp8, extracted_scale = extract_indexer_ks(buffer, test_indices) - - # Verify results - print(f"Original k_fp8 shape: {test_k_fp8.shape}, dtype: {test_k_fp8.dtype}") - print(f"Extracted k_fp8 shape: {extracted_fp8.shape}, dtype: 
{extracted_fp8.dtype}") - print(f"Original k_scale shape: {test_k_scale.shape}, dtype: {test_k_scale.dtype}") - print(f"Extracted k_scale shape: {extracted_scale.shape}, dtype: {extracted_scale.dtype}") - - # Check if extraction matches (convert fp8 to float32 for comparison) - # Use higher tolerance for fp8 due to quantization precision - fp8_match = torch.allclose(test_k_fp8_fp32, extracted_fp8.float(), atol=0.1, rtol=0.1) - scale_match = torch.allclose(test_k_scale, extracted_scale, atol=1e-6) - - print(f"FP8 data matches: {fp8_match}") - print(f"Scale data matches: {scale_match}") - - if fp8_match and scale_match: - print("All tests passed!") - else: - print("Test failed!") - if not fp8_match: - print("First few fp8 values:") - print(f"Original: {test_k_fp8_fp32[0, :5]}") - print(f"Extracted: {extracted_fp8.float()[0, :5]}") - if not scale_match: - print(f"Max scale diff: {torch.max(torch.abs(test_k_scale - extracted_scale))}") +def test_extract_indexer_ks(): + """Test the extract_indexer_ks kernel against the copy kernel""" + import torch.nn.functional as F + from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks + + print("=" * 80) + print("Testing extract_indexer_ks") + print("=" * 80) + + # Test parameters + q_seq_len = 50 + head_dim = 128 + large_size = 1024 + dtype = torch.bfloat16 + fp8_type = torch.float8_e4m3fn + + # Create random indices for writing + write_indices = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() + actual_seq_len = len(write_indices) + + # Create input tensors + k_bf16_original = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") + + # Quantize to FP8 + k_abs_max = k_bf16_original.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_original = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_original = (k_bf16_original / k_abs_max).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + 
).to(fp8_type) + + # Create buffer and write data using destindex_copy_indexer_ks + buffer = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_original, k_scale_original, write_indices, buffer) + + # Now extract the data back using extract_indexer_ks + k_fp8_extracted, k_scale_extracted = extract_indexer_ks(buffer, write_indices) + + # Verify FP8 values match + fp8_match = torch.allclose( + k_fp8_extracted.to(torch.float32), + k_fp8_original.to(torch.float32), + atol=0, rtol=0 + ) + + # Verify scales match + scale_match = torch.allclose( + k_scale_extracted, + k_scale_original.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + # Check dequantized values + k_dequant_extracted = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) + cosine_sim = F.cosine_similarity(k_dequant_extracted, k_bf16_original, dim=-1).mean() + + print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") + print(f" FP8 values match: {fp8_match}") + print(f" Scale values match: {scale_match}") + print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") + + assert fp8_match, "FP8 values do not match!" + assert scale_match, "Scale values do not match!" 
+ assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" + + print("✓ Basic test passed!") + print() + + # Test with sequential indices + print("Testing sequential indices...") + write_indices_seq = torch.arange(20, device="cuda", dtype=torch.int32) + k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") + k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, write_indices_seq, buffer_seq) + k_fp8_ext_seq, k_scale_ext_seq = extract_indexer_ks(buffer_seq, write_indices_seq) + + fp8_match_seq = torch.allclose( + k_fp8_ext_seq.to(torch.float32), + k_fp8_seq.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_seq = torch.allclose( + k_scale_ext_seq, + k_scale_seq.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Sequential indices: FP8={fp8_match_seq}, Scale={scale_match_seq}") + assert fp8_match_seq and scale_match_seq + print("✓ Sequential test passed!") + print() + + # Test with single element + print("Testing single element...") + write_idx_single = torch.tensor([42], device="cuda", dtype=torch.int32) + k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") + k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_single, k_scale_single, write_idx_single, buffer_single) + k_fp8_ext_single, 
k_scale_ext_single = extract_indexer_ks(buffer_single, write_idx_single) + + fp8_match_single = torch.allclose( + k_fp8_ext_single.to(torch.float32), + k_fp8_single.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_single = torch.allclose( + k_scale_ext_single, + k_scale_single.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Single element: FP8={fp8_match_single}, Scale={scale_match_single}") + assert fp8_match_single and scale_match_single + print("✓ Single element test passed!") + print() + + # Test with larger batch to check performance characteristics + print("Testing larger batch (performance check)...") + write_indices_large = torch.randint(0, large_size * 10, (500,), device="cuda", dtype=torch.int32).unique() + actual_large_len = len(write_indices_large) + k_bf16_large = torch.randn((actual_large_len, head_dim), dtype=dtype, device="cuda") + k_abs_max_large = k_bf16_large.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_large = (k_abs_max_large / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_large = (k_bf16_large / k_abs_max_large).clamp( + torch.finfo(fp8_type).min, torch.finfo(fp8_type).max + ).to(fp8_type) + + buffer_large = torch.zeros((large_size * 10, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_large, k_scale_large, write_indices_large, buffer_large) + + # Warm up + for _ in range(3): + _ = extract_indexer_ks(buffer_large, write_indices_large) + + # Time it + torch.cuda.synchronize() + import time + start = time.time() + for _ in range(100): + k_fp8_ext_large, k_scale_ext_large = extract_indexer_ks(buffer_large, write_indices_large) + torch.cuda.synchronize() + elapsed = time.time() - start + + fp8_match_large = torch.allclose( + k_fp8_ext_large.to(torch.float32), + k_fp8_large.to(torch.float32), + atol=0, rtol=0 + ) + scale_match_large = torch.allclose( + k_scale_ext_large, + k_scale_large.squeeze(-1), + atol=1e-6, rtol=1e-5 + ) + + print(f" Large batch (size={actual_large_len}): 
FP8={fp8_match_large}, Scale={scale_match_large}") + print(f" Average time per call: {elapsed/100*1000:.3f} ms") + assert fp8_match_large and scale_match_large + print("✓ Large batch test passed!") + print() + + print("=" * 80) + print("All tests passed successfully! ✓") + print("=" * 80) if __name__ == "__main__": - test() + test_extract_indexer_ks() From bcafce3e68e88b0cd84d20c465cf1ede41e7b602 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 10 Nov 2025 14:32:31 +0000 Subject: [PATCH 12/58] fix cudagraph --- lightllm/models/deepseek3_2/infer_struct.py | 168 +++++++++++++++++--- 1 file changed, 147 insertions(+), 21 deletions(-) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index c122c6a7e7..db6e61a1c8 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,8 +1,10 @@ import torch +import weakref from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): + _shared_nsa_buffers = None def __init__(self): super().__init__() @@ -14,8 +16,42 @@ def __init__(self): self.index_topk = 2048 return + @classmethod + def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): + """Get or create pre-allocated buffers for CUDA graph execution""" + if cls._shared_nsa_buffers is None: + # Pre-allocate buffers for max possible sizes + max_total_q_tokens = graph_max_batch_size * max_seq_len + max_total_tokens = graph_max_batch_size * max_seq_len + + cls._shared_nsa_buffers = [ + { + 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'page_table_size_1': torch.empty(graph_max_batch_size, 
max_seq_len, dtype=torch.int, device='cuda'), + 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), + 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), + 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), + }, + { # Second buffer for microbatch overlap if needed + 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), + 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), + 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), + 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), + 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), + } + ] + return cls._shared_nsa_buffers + def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) + + # Store weak reference to model for accessing graph parameters + self._model_ref = weakref.ref(model) + assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager @@ -29,11 +65,34 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.b_ready_cache_len is None: self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len - self.nsa_cache_seqlens = self.b_att_seq_len.clamp(max=self.index_topk) + # Check if we can use CUDA graph based on batch size and max_len constraints + use_cuda_graph_buffers = False + if (hasattr(model, 'graph_max_batch_size') and + hasattr(model, 'graph_max_len_in_batch') and + self.batch_size <= model.graph_max_batch_size and + self.max_len_in_batch <= model.graph_max_len_in_batch): + use_cuda_graph_buffers = True + + # Setup 
nsa_cache_seqlens and nsa_cu_seqlens_k with pre-allocated buffers if using CUDA graph + if use_cuda_graph_buffers: + buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) + buffer = buffers[self.microbatch_index] + + # Use views into pre-allocated buffers + self.nsa_cache_seqlens = buffer['nsa_cache_seqlens'][:self.batch_size] + self.nsa_cu_seqlens_k = buffer['nsa_cu_seqlens_k'][:self.batch_size + 1] + else: + # Create new tensors dynamically + self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device='cuda') + self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device='cuda') + + # Calculate actual values + self.nsa_cache_seqlens.copy_(self.b_att_seq_len.clamp(max=self.index_topk)) assert self.nsa_cache_seqlens.dtype == torch.int32 - self.nsa_cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32), (1, 0) - ) + + # Compute cumulative sum with padding + torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) + self.nsa_cu_seqlens_k[0] = 0 # Pre-compute NSA indexer indexing structures self._init_nsa_indexing_structures() @@ -46,22 +105,89 @@ def _init_nsa_indexing_structures(self): lengths_list = [] offset = 0 num_seq_len = self.b_req_idx.shape[0] - self.page_table_size_1 = torch.zeros((num_seq_len, self.b_seq_len.max()), dtype=torch.int, device='cuda') + max_seq_len = self.b_seq_len.max().item() + + # Calculate total sizes needed + total_q_len = sum(self.b_q_seq_len[i].item() for i in range(num_seq_len)) + total_seq_len = sum(self.b_seq_len[i].item() for i in range(num_seq_len)) + + # Check if we should use CUDA graph buffers + use_cuda_graph_buffers = False + if hasattr(self, '_model_ref'): + model = self._model_ref() + if (model is not None and + hasattr(model, 'graph_max_batch_size') and + hasattr(model, 'graph_max_len_in_batch') and + self.batch_size <= model.graph_max_batch_size and + 
self.max_len_in_batch <= model.graph_max_len_in_batch): + use_cuda_graph_buffers = True + + if use_cuda_graph_buffers: + # Use pre-allocated buffers for CUDA graph + model = self._model_ref() + buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) + buffer = buffers[self.microbatch_index] + + # Use views into pre-allocated buffers + self.ks = buffer['ks'][:total_q_len] + self.ke = buffer['ke'][:total_q_len] + self.lengths = buffer['lengths'][:total_q_len] + self.page_table_size_1 = buffer['page_table_size_1'][:num_seq_len, :max_seq_len] + self.req_all_mem_index = buffer['req_all_mem_index'][:total_seq_len] + + # Zero out page_table_size_1 before filling + self.page_table_size_1.zero_() + + # Compute and copy values into the pre-allocated buffer views + ks_offset = 0 + ke_offset = 0 + lengths_offset = 0 + req_offset = 0 + seq_offset = 0 + + for i in range(num_seq_len): + seq_len = self.b_seq_len[i].item() + q_seq_len = self.b_q_seq_len[i].item() + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + + # Copy req_all_mem_index + self.req_all_mem_index[req_offset:req_offset + seq_len] = mem_index + + # Fill page_table_size_1 + self.page_table_size_1[i, :seq_len] = mem_index + + # Fill ks, ke, lengths + self.ks[ks_offset:ks_offset + q_seq_len].fill_(seq_offset) + self.ke[ke_offset:ke_offset + q_seq_len] = torch.arange( + seq_offset + 1, seq_offset + q_seq_len + 1, dtype=torch.int, device='cuda' + ) + self.lengths[lengths_offset:lengths_offset + q_seq_len] = torch.arange( + seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda' + ) + + ks_offset += q_seq_len + ke_offset += q_seq_len + lengths_offset += q_seq_len + req_offset += seq_len + seq_offset += seq_len + else: + # Original dynamic allocation for non-CUDA graph mode + self.page_table_size_1 = torch.zeros((num_seq_len, max_seq_len), dtype=torch.int, device='cuda') - for i in range(num_seq_len): - seq_len = self.b_seq_len[i] - q_seq_len = 
self.b_q_seq_len[i] - mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] - req_all_mem_index_list.append(mem_index) - self.page_table_size_1[i, :seq_len] = mem_index - ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset - ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 - ks_list.append(ks) - ke_list.append(ke) - lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) - offset += seq_len + for i in range(num_seq_len): + seq_len = self.b_seq_len[i].item() + q_seq_len = self.b_q_seq_len[i].item() + mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + req_all_mem_index_list.append(mem_index) + self.page_table_size_1[i, :seq_len] = mem_index + ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset + ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 + ks_list.append(ks) + ke_list.append(ke) + lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) + offset += seq_len - self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) - self.ks = torch.cat(ks_list, dim=0) - self.ke = torch.cat(ke_list, dim=0) - self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file + self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) + self.ks = torch.cat(ks_list, dim=0) + self.ke = torch.cat(ke_list, dim=0) + self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file From 66138e5a2cd922d5994d80fc2a791ad664d0f6c8 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 26 Dec 2025 10:53:38 +0000 Subject: [PATCH 13/58] can run --- lightllm/common/basemodel/basemodel.py | 20 ++ lightllm/common/infer_utils.py | 70 +++++- .../kv_cache_mem_manager/mem_manager.py | 15 +- lightllm/models/deepseek3_2/infer_struct.py | 144 ++++++----- .../layer_infer/nsa_indexer_layer_inder.py | 87 ++++--- .../layer_infer/transformer_layer_infer.py | 33 +-- 
.../triton_kernel/copy_indexer_ks.py | 232 ++++++++++++++++++ .../destindex_copy_indexer_ks.py | 151 +++++------- 8 files changed, 552 insertions(+), 200 deletions(-) create mode 100644 lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 5c1d2b8712..c11b68c99c 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -476,6 +476,24 @@ def _prefill( ) infer_state = self._create_inferstate(model_input) + + # Capture old indexer_ks positions before they are overwritten + # This is needed for DeepSeek v3.2 to copy cached tokens' indexer_ks + old_indexer_ks_positions = [] + for i in range(infer_state.b_req_idx.shape[0]): + req_idx = infer_state.b_req_idx[i].item() + ready_cache_len = infer_state.b_ready_cache_len[i].item() + + if ready_cache_len > 0: + # Capture old positions for cached tokens + old_pos = self.req_manager.req_to_token_indexs[ + req_idx, 0:ready_cache_len + ].clone() # Clone to avoid view issues + old_indexer_ks_positions.append(old_pos) + else: + # No cached tokens for this request + old_indexer_ks_positions.append(None) + init_req_to_token_indexes( req_to_token_indexs=self.req_manager.req_to_token_indexs, b_req_idx=infer_state.b_req_idx, @@ -484,6 +502,8 @@ def _prefill( b_start_loc=model_input.b_prefill_start_loc, alloc_mem_index=infer_state.mem_index, max_q_seq_len=infer_state.max_q_seq_len, + mem_manager=self.req_manager.mem_manager, + old_indexer_ks_positions=old_indexer_ks_positions, ) prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index e1b9cc3830..ed3c0b73e4 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -2,8 +2,17 @@ def init_req_to_token_indexes( - req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, b_start_loc, alloc_mem_index, 
max_q_seq_len + req_to_token_indexs, + b_req_idx, + b_seq_len, + b_ready_cache_len, + b_start_loc, + alloc_mem_index, + max_q_seq_len, + mem_manager=None, + old_indexer_ks_positions=None, ): + # Step 1: Copy KV cache for NEW tokens (existing logic) copy_kv_index_to_req_prefill( req_to_token_indexs=req_to_token_indexs, b_req_idx=b_req_idx, @@ -13,3 +22,62 @@ def init_req_to_token_indexes( memindex=alloc_mem_index, max_q_seq_len=max_q_seq_len, ) + + # Step 2: Copy indexer_ks for CACHED tokens (DeepSeek v3.2 specific) + # This ensures consistency between KV cache and indexer_ks buffers + # when prefix cache is hit + if ( + mem_manager is not None + and hasattr(mem_manager, "indexer_ks_mem_manager") + and old_indexer_ks_positions is not None + ): + + _copy_cached_indexer_ks_to_new_positions( + req_to_token_indexs=req_to_token_indexs, + b_req_idx=b_req_idx, + b_ready_cache_len=b_ready_cache_len, + mem_manager=mem_manager, + old_indexer_ks_positions=old_indexer_ks_positions, + ) + + +def _copy_cached_indexer_ks_to_new_positions( + req_to_token_indexs, + b_req_idx, + b_ready_cache_len, + mem_manager, + old_indexer_ks_positions, +): + """ + Copy cached tokens' indexer_ks from old positions to new positions. + + This function is called after copy_kv_index_to_req_prefill() has updated + req_to_token_indexs to point to new contiguous positions. We need to copy + indexer_ks data to match the KV cache layout. 
+ + For each layer and each request with cached tokens: + - Copy indexer_ks data from old positions to new positions + - This ensures consistency when using extract_indexer_ks later + """ + from lightllm.models.deepseek3_2.triton_kernel.copy_indexer_ks import copy_indexer_ks + + # Get number of layers from indexer_ks_mem_manager + num_layers = len(mem_manager.indexer_ks_mem_manager.kv_buffer) + indexer_buffer = mem_manager.indexer_ks_mem_manager.kv_buffer + + for layer_idx in range(num_layers): + for i in range(b_req_idx.shape[0]): + req_idx = b_req_idx[i].item() + ready_cache_len = b_ready_cache_len[i].item() + old_positions = old_indexer_ks_positions[i] + + if ready_cache_len > 0 and old_positions is not None: + # New positions after copy_kv_index_to_req_prefill + new_positions = req_to_token_indexs[req_idx, 0:ready_cache_len] + + # Copy indexer_ks: old_positions -> new_positions + copy_indexer_ks( + buffer=indexer_buffer[layer_idx], + src_loc=old_positions, + dest_loc=new_positions, + ) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 2940d74e21..7d5e2af046 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -26,7 +26,9 @@ class MemoryManager: - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): + def __init__( + self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False + ): self.size = size self.head_num = head_num self.head_dim = head_dim @@ -93,6 +95,17 @@ def profile_size(self, mem_fraction): available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) cell_size = self.get_cell_size() self.size = int(available_memory * 1024 ** 3 / cell_size) + + # Ensure size is at least a minimum positive value to avoid torch.arange errors + MIN_SIZE = 1024 # Minimum 
1024 tokens + if self.size < MIN_SIZE: + logger.warning( + f"Insufficient memory for KV cache. Available: {available_memory:.2f} GB, " + f"but calculated size is {self.size} tokens. Using minimum size {MIN_SIZE} tokens instead. " + f"Consider reducing model size, using fewer GPUs, or increasing mem_fraction." + ) + self.size = MIN_SIZE + if world_size > 1: tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}") dist.all_reduce(tensor, op=dist.ReduceOp.MIN) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index db6e61a1c8..2f8aa75629 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -3,6 +3,7 @@ from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager + class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): _shared_nsa_buffers = None @@ -23,35 +24,35 @@ def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): # Pre-allocate buffers for max possible sizes max_total_q_tokens = graph_max_batch_size * max_seq_len max_total_tokens = graph_max_batch_size * max_seq_len - + cls._shared_nsa_buffers = [ { - 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), - 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), - 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), - 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), + "ks": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "ke": 
torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "lengths": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "page_table_size_1": torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device="cuda"), + "req_all_mem_index": torch.empty(max_total_tokens, dtype=torch.int64, device="cuda"), + "nsa_cache_seqlens": torch.empty(graph_max_batch_size, dtype=torch.int32, device="cuda"), + "nsa_cu_seqlens_k": torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device="cuda"), }, { # Second buffer for microbatch overlap if needed - 'ks': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'ke': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'lengths': torch.empty(max_total_q_tokens, dtype=torch.int, device='cuda'), - 'page_table_size_1': torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device='cuda'), - 'req_all_mem_index': torch.empty(max_total_tokens, dtype=torch.int64, device='cuda'), - 'nsa_cache_seqlens': torch.empty(graph_max_batch_size, dtype=torch.int32, device='cuda'), - 'nsa_cu_seqlens_k': torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device='cuda'), - } + "ks": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "ke": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "lengths": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), + "page_table_size_1": torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device="cuda"), + "req_all_mem_index": torch.empty(max_total_tokens, dtype=torch.int64, device="cuda"), + "nsa_cache_seqlens": torch.empty(graph_max_batch_size, dtype=torch.int32, device="cuda"), + "nsa_cu_seqlens_k": torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device="cuda"), + }, ] return cls._shared_nsa_buffers def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) - + # Store weak reference to model for accessing graph 
parameters self._model_ref = weakref.ref(model) - + assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager @@ -60,36 +61,39 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): # b_ready_cache_len is already set in basemodel.py for prefill pass else: - # In decode mode, b_ready_cache_len should be b_seq_len - b_q_seq_len - # since b_q_seq_len represents the new tokens being processed + # In decode mode, b_ready_cache_len is set by the router/scheduler + # based on actual prefix cache hits. If it's None (no prefix cache enabled), + # it should be 0, not computed from b_seq_len - b_q_seq_len if self.b_ready_cache_len is None: - self.b_ready_cache_len = self.b_seq_len - self.b_q_seq_len + self.b_ready_cache_len = torch.zeros_like(self.b_seq_len) # Check if we can use CUDA graph based on batch size and max_len constraints use_cuda_graph_buffers = False - if (hasattr(model, 'graph_max_batch_size') and - hasattr(model, 'graph_max_len_in_batch') and - self.batch_size <= model.graph_max_batch_size and - self.max_len_in_batch <= model.graph_max_len_in_batch): + if ( + hasattr(model, "graph_max_batch_size") + and hasattr(model, "graph_max_len_in_batch") + and self.batch_size <= model.graph_max_batch_size + and self.max_len_in_batch <= model.graph_max_len_in_batch + ): use_cuda_graph_buffers = True - + # Setup nsa_cache_seqlens and nsa_cu_seqlens_k with pre-allocated buffers if using CUDA graph if use_cuda_graph_buffers: buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) buffer = buffers[self.microbatch_index] - + # Use views into pre-allocated buffers - self.nsa_cache_seqlens = buffer['nsa_cache_seqlens'][:self.batch_size] - self.nsa_cu_seqlens_k = buffer['nsa_cu_seqlens_k'][:self.batch_size + 1] + self.nsa_cache_seqlens = buffer["nsa_cache_seqlens"][: self.batch_size] + self.nsa_cu_seqlens_k = buffer["nsa_cu_seqlens_k"][: self.batch_size + 1] 
else: # Create new tensors dynamically - self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device='cuda') - self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device='cuda') - + self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device="cuda") + self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device="cuda") + # Calculate actual values self.nsa_cache_seqlens.copy_(self.b_att_seq_len.clamp(max=self.index_topk)) assert self.nsa_cache_seqlens.dtype == torch.int32 - + # Compute cumulative sum with padding torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) self.nsa_cu_seqlens_k[0] = 0 @@ -106,65 +110,68 @@ def _init_nsa_indexing_structures(self): offset = 0 num_seq_len = self.b_req_idx.shape[0] max_seq_len = self.b_seq_len.max().item() - + # Calculate total sizes needed total_q_len = sum(self.b_q_seq_len[i].item() for i in range(num_seq_len)) total_seq_len = sum(self.b_seq_len[i].item() for i in range(num_seq_len)) - + # Check if we should use CUDA graph buffers use_cuda_graph_buffers = False - if hasattr(self, '_model_ref'): + if hasattr(self, "_model_ref"): model = self._model_ref() - if (model is not None and - hasattr(model, 'graph_max_batch_size') and - hasattr(model, 'graph_max_len_in_batch') and - self.batch_size <= model.graph_max_batch_size and - self.max_len_in_batch <= model.graph_max_len_in_batch): + if ( + model is not None + and hasattr(model, "graph_max_batch_size") + and hasattr(model, "graph_max_len_in_batch") + and self.batch_size <= model.graph_max_batch_size + and self.max_len_in_batch <= model.graph_max_len_in_batch + ): use_cuda_graph_buffers = True - + if use_cuda_graph_buffers: # Use pre-allocated buffers for CUDA graph model = self._model_ref() buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) buffer = buffers[self.microbatch_index] - + # Use views into pre-allocated 
buffers - self.ks = buffer['ks'][:total_q_len] - self.ke = buffer['ke'][:total_q_len] - self.lengths = buffer['lengths'][:total_q_len] - self.page_table_size_1 = buffer['page_table_size_1'][:num_seq_len, :max_seq_len] - self.req_all_mem_index = buffer['req_all_mem_index'][:total_seq_len] - + self.ks = buffer["ks"][:total_q_len] + self.ke = buffer["ke"][:total_q_len] + self.lengths = buffer["lengths"][:total_q_len] + self.page_table_size_1 = buffer["page_table_size_1"][:num_seq_len, :max_seq_len] + self.req_all_mem_index = buffer["req_all_mem_index"][:total_seq_len] + # Zero out page_table_size_1 before filling self.page_table_size_1.zero_() - + # Compute and copy values into the pre-allocated buffer views ks_offset = 0 ke_offset = 0 lengths_offset = 0 req_offset = 0 seq_offset = 0 - + for i in range(num_seq_len): seq_len = self.b_seq_len[i].item() q_seq_len = self.b_q_seq_len[i].item() - mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] - + req_idx = self.b_req_idx[i].item() + mem_index = self.req_manager.req_to_token_indexs[req_idx, :seq_len] + # Copy req_all_mem_index - self.req_all_mem_index[req_offset:req_offset + seq_len] = mem_index - + self.req_all_mem_index[req_offset : req_offset + seq_len] = mem_index + # Fill page_table_size_1 self.page_table_size_1[i, :seq_len] = mem_index - + # Fill ks, ke, lengths - self.ks[ks_offset:ks_offset + q_seq_len].fill_(seq_offset) - self.ke[ke_offset:ke_offset + q_seq_len] = torch.arange( - seq_offset + 1, seq_offset + q_seq_len + 1, dtype=torch.int, device='cuda' + self.ks[ks_offset : ks_offset + q_seq_len].fill_(seq_offset) + self.ke[ke_offset : ke_offset + q_seq_len] = torch.arange( + seq_offset + 1, seq_offset + q_seq_len + 1, dtype=torch.int, device="cuda" ) - self.lengths[lengths_offset:lengths_offset + q_seq_len] = torch.arange( - seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda' + self.lengths[lengths_offset : lengths_offset + q_seq_len] = torch.arange( + seq_len - q_seq_len + 1, 
seq_len + 1, dtype=torch.int, device="cuda" ) - + ks_offset += q_seq_len ke_offset += q_seq_len lengths_offset += q_seq_len @@ -172,22 +179,23 @@ def _init_nsa_indexing_structures(self): seq_offset += seq_len else: # Original dynamic allocation for non-CUDA graph mode - self.page_table_size_1 = torch.zeros((num_seq_len, max_seq_len), dtype=torch.int, device='cuda') + self.page_table_size_1 = torch.zeros((num_seq_len, max_seq_len), dtype=torch.int, device="cuda") for i in range(num_seq_len): seq_len = self.b_seq_len[i].item() q_seq_len = self.b_q_seq_len[i].item() - mem_index = self.req_manager.req_to_token_indexs[i, :seq_len] + req_idx = self.b_req_idx[i].item() + mem_index = self.req_manager.req_to_token_indexs[req_idx, :seq_len] req_all_mem_index_list.append(mem_index) self.page_table_size_1[i, :seq_len] = mem_index - ks = torch.zeros(q_seq_len, dtype=torch.int, device='cuda') + offset - ke = torch.arange(q_seq_len, dtype=torch.int, device='cuda') + offset + 1 + ks = torch.zeros(q_seq_len, dtype=torch.int, device="cuda") + offset + ke = torch.arange(q_seq_len, dtype=torch.int, device="cuda") + offset + 1 ks_list.append(ks) ke_list.append(ke) - lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device='cuda')) + lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device="cuda")) offset += seq_len self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) self.ks = torch.cat(ks_list, dim=0) self.ke = torch.cat(ke_list, dim=0) - self.lengths = torch.cat(lengths_list, dim=0) \ No newline at end of file + self.lengths = torch.cat(lengths_list, dim=0) diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index df045dd2d2..2f4421e742 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -16,6 +16,7 
@@ logger = init_logger(__name__) + class NSAIndexerInfer(BaseLayerInfer): def __init__(self, layer_idx, network_config, mode=[]): super().__init__() @@ -38,13 +39,20 @@ def __init__(self, layer_idx, network_config, mode=[]): return - def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False): + def ref_fp8_mqa_logits( + self, + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + cost_only: bool = False, + ): seq_len_kv = kv.shape[0] if cost_only: start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) - end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) + end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) count_ones_per_row = (end - start).clamp(min=0) return count_ones_per_row.sum() @@ -52,29 +60,31 @@ def ref_fp8_mqa_logits(self, q: torch.Tensor, kv: torch.Tensor, weights: torch.T q = q.float() k = k.float() - mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] - mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] mask = mask_lo & mask_hi - score = torch.einsum('mhd,nd->hmn', q, k) + score = torch.einsum("mhd,nd->hmn", q, k) logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float('-inf')) + logits = logits.masked_fill(~mask, float("-inf")) cost = mask.sum() return logits, cost - def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight) -> torch.Tensor: + def get_indices( + self, + hidden_states: torch.Tensor, + q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, + 
layer_weight: NSAIndexerWeight, + ) -> torch.Tensor: q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) destindex_copy_indexer_ks( - k_fp8, - k_scale, - infer_state.mem_index, - infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + k_fp8, k_scale, infer_state.mem_index, infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] ) weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale @@ -87,34 +97,47 @@ def get_indices(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, # Use efficient Triton kernel to extract FP8 keys and scales from buffer k_fp8_, k_scale_ = extract_indexer_ks( - infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], - infer_state.req_all_mem_index + infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], infer_state.req_all_mem_index ) - logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) - + # Get actual sequence length from q (which comes from q_lora) + # This may differ from ks.shape[0] during certain operations + actual_seq_len = q.shape[0] + + # ks, ke, lengths, and weights should all match actual_seq_len + # Slice them if they don't match + if ks.shape[0] != actual_seq_len: + ks = ks[:actual_seq_len] + ke = ke[:actual_seq_len] + lengths = lengths[:actual_seq_len] + weights = weights[:actual_seq_len] + + logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + return fast_topk_transform_fused( - score=logits, - lengths=lengths, - page_table_size_1=page_table_1, - cu_seqlens_q=infer_state.cu_seqlens_q, + score=logits, + lengths=lengths, + page_table_size_1=page_table_1, + cu_seqlens_q=infer_state.cu_seqlens_q, topk=self.index_topk, ) - @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: assert x.dtype == torch.bfloat16 from sgl_kernel import 
hadamard_transform hidden_size = x.size(-1) - assert ( - hidden_size & (hidden_size - 1) - ) == 0, "Hidden size must be a power of 2 for Hadamard transform." - return hadamard_transform(x, scale=hidden_size**-0.5) - - def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: NSAIndexerWeight): + assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform." + return hadamard_transform(x, scale=hidden_size ** -0.5) + + def _get_q_k_bf16( + self, + hidden_states: torch.Tensor, + q_lora: torch.Tensor, + infer_state: Deepseek3_2FlashAttentionStateInfo, + layer_weight: NSAIndexerWeight, + ): q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) k = layer_weight.wk_proj_.mm(hidden_states) @@ -123,11 +146,13 @@ def _get_q_k_bf16(self, hidden_states: torch.Tensor, q_lora: torch.Tensor, k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps ).type_as(k) + # Slice position_cos and position_sin to match actual token length + actual_seq_len = q.shape[0] rotary_emb_fwd( q[:, :, : self.qk_rope_head_dim], k[:, None, : self.qk_rope_head_dim], - infer_state.position_cos, - infer_state.position_sin, + infer_state.position_cos[:actual_seq_len], + infer_state.position_sin[:actual_seq_len], ) q = self._rotate_activation(q) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index df52204270..cf748bcdb6 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -19,11 +19,7 @@ def __init__(self, layer_num, network_config, mode=[]): self.index_topk = network_config["index_topk"] super().__init__(layer_num, network_config, mode) - self.indexer = NSAIndexerInfer( - layer_idx=self.layer_num_, - 
network_config=self.network_config_, - mode=mode - ) + self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_, mode=mode) self.topk_indices = None return @@ -41,6 +37,9 @@ def _get_qkv( ) q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) + # Process all tokens for indexer + # Note: Prefix cache slicing optimization is disabled due to batch structure + # mismatch issues with fast_topk_transform_fused kernel self.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) q = layer_weight.q_b_proj_.mm(q) @@ -81,12 +80,12 @@ def _nsa_context_attention_kernel( layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: - + q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) mla_out, _, _ = flash_mla_sparse_fwd( - q=q_all, + q=q_all, kv=infer_state.mem_manager.kv_buffer[self.layer_num_], indices=self.topk_indices.unsqueeze(1), sm_scale=self.softmax_scale, @@ -95,7 +94,11 @@ def _nsa_context_attention_kernel( return mla_out def _nsa_token_attention_kernel( - self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None + self, + q, + infer_state: Deepseek3_2FlashAttentionStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + out=None, ): q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) @@ -104,16 +107,16 @@ def _nsa_token_attention_kernel( kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, + q=q_rope, + k_cache=k_rope, v_cache=kv_nope, - qv=q_nope, - page_table=self.topk_indices, - cache_seqlens=infer_state.nsa_cache_seqlens, 
+ qv=q_nope, + page_table=self.topk_indices, + cache_seqlens=infer_state.nsa_cache_seqlens, cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, + cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, max_seqlen_q=infer_state.max_q_seq_len, softmax_scale=self.softmax_scale, causal=True, ) - return o_tensor \ No newline at end of file + return o_tensor diff --git a/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py new file mode 100644 index 0000000000..93cf463eb0 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py @@ -0,0 +1,232 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_copy_indexer_ks( + buffer, # [large_size, 1, 132] uint8 + src_loc, # [copy_len] int32/int64 - source indices + dest_loc, # [copy_len] int32/int64 - destination indices + stride_bs, + stride_h, + stride_d, + BLOCK_KV: tl.constexpr, # = 128 (FP8 data) + BLOCK_SCALE: tl.constexpr, # = 4 (scale data) +): + """ + Triton kernel to copy indexer_ks data from source locations to destination locations. + + This kernel copies 132-byte indexer_ks entries (128 bytes FP8 key + 4 bytes float32 scale) + from source positions to destination positions within the same buffer. 
+ + Args: + buffer: Shared buffer containing indexer_ks data [large_size, 1, 132] uint8 + src_loc: Source indices to copy from [copy_len] + dest_loc: Destination indices to copy to [copy_len] + stride_bs, stride_h, stride_d: Strides for the buffer + BLOCK_KV: Size of FP8 key data (128 bytes) + BLOCK_SCALE: Size of scale data (4 bytes) + """ + cur_index = tl.program_id(0) + offs_kv = tl.arange(0, BLOCK_KV) + offs_scale = tl.arange(0, BLOCK_SCALE) + + # Load source and destination indices + src_index = tl.load(src_loc + cur_index).to(tl.int64) + dest_index = tl.load(dest_loc + cur_index).to(tl.int64) + + # Copy FP8 key data (128 bytes) + src_kv_ptrs = buffer + src_index * stride_bs + stride_d * offs_kv + dest_kv_ptrs = buffer + dest_index * stride_bs + stride_d * offs_kv + kv_data = tl.load(src_kv_ptrs) + tl.store(dest_kv_ptrs, kv_data) + + # Copy scale data (4 bytes at offset 128) + src_scale_base = buffer + src_index * stride_bs + BLOCK_KV * stride_d + dest_scale_base = buffer + dest_index * stride_bs + BLOCK_KV * stride_d + scale_data = tl.load(src_scale_base + offs_scale * stride_d) + tl.store(dest_scale_base + offs_scale * stride_d, scale_data) + + return + + +@torch.no_grad() +def copy_indexer_ks( + buffer: torch.Tensor, + src_loc: torch.Tensor, + dest_loc: torch.Tensor, +): + """ + Copy indexer_ks data from source positions to destination positions. + + This function is used to copy cached tokens' indexer_ks data to new locations + after prefix cache matching. It ensures that the indexer_ks buffer stays + consistent with the KV cache buffer. 
+ + Args: + buffer: [large_size, 1, 132] torch.uint8 + Buffer containing indexer_ks data (same buffer for src and dest) + src_loc: [copy_len] torch.int32 or torch.int64 + Source indices in buffer (old positions) + dest_loc: [copy_len] torch.int32 or torch.int64 + Destination indices in buffer (new positions) + + Returns: + None (modifies buffer in-place) + + Example: + >>> buffer = torch.zeros((1024, 1, 132), dtype=torch.uint8).cuda() + >>> old_pos = torch.tensor([100, 101, 102], dtype=torch.int32).cuda() + >>> new_pos = torch.tensor([200, 201, 202], dtype=torch.int32).cuda() + >>> copy_indexer_ks(buffer, old_pos, new_pos) + # Data from positions [100, 101, 102] is now copied to [200, 201, 202] + """ + copy_len = src_loc.shape[0] + block_kv = 128 # FP8 key data size + block_scale = 4 # Float32 scale size + + assert ( + src_loc.shape[0] == dest_loc.shape[0] + ), f"src_loc and dest_loc must have same length: {src_loc.shape[0]} != {dest_loc.shape[0]}" + assert ( + buffer.shape[2] == block_kv + block_scale + ), f"Expected buffer last dim={block_kv + block_scale}, got {buffer.shape[2]}" + assert buffer.dtype == torch.uint8, f"Expected buffer dtype=uint8, got {buffer.dtype}" + + grid = (copy_len,) + num_warps = 1 + + _fwd_kernel_copy_indexer_ks[grid]( + buffer, + src_loc, + dest_loc, + buffer.stride(0), + buffer.stride(1), + buffer.stride(2), + BLOCK_KV=block_kv, + BLOCK_SCALE=block_scale, + num_warps=num_warps, + num_stages=1, + ) + + return + + +def test_copy_indexer_ks(): + """Test the copy_indexer_ks kernel""" + import torch.nn.functional as F + from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks + from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks + + print("=" * 80) + print("Testing copy_indexer_ks") + print("=" * 80) + + # Test parameters + cached_len = 20 + buffer_size = 1024 + head_dim = 128 + dtype = torch.bfloat16 + fp8_type = torch.float8_e4m3fn + + # Create 
indexer_ks data + k_bf16 = torch.randn((cached_len, head_dim), dtype=dtype, device="cuda") + + # Quantize to FP8 + k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8 = (k_bf16 / k_abs_max).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + + # Write to old positions + old_positions = torch.arange(100, 100 + cached_len, dtype=torch.int32, device="cuda") + buffer = torch.zeros((buffer_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8, k_scale, old_positions, buffer) + + # Copy to new positions + new_positions = torch.arange(200, 200 + cached_len, dtype=torch.int32, device="cuda") + copy_indexer_ks(buffer, old_positions, new_positions) + + # Verify data at new positions matches original + k_fp8_extracted, k_scale_extracted = extract_indexer_ks(buffer, new_positions) + + fp8_match = torch.allclose(k_fp8_extracted.to(torch.float32), k_fp8.to(torch.float32), atol=0, rtol=0) + + scale_match = torch.allclose(k_scale_extracted, k_scale.squeeze(-1), atol=1e-6, rtol=1e-5) + + # Check dequantized values + k_dequant_extracted = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) + cosine_sim = F.cosine_similarity(k_dequant_extracted, k_bf16, dim=-1).mean() + + print(f"Cached tokens: {cached_len}, Head dim: {head_dim}") + print(f" FP8 values match: {fp8_match}") + print(f" Scale values match: {scale_match}") + print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") + + assert fp8_match, "FP8 values do not match!" + assert scale_match, "Scale values do not match!" 
+ assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" + + print("✓ Basic test passed!") + print() + + # Test with sequential indices + print("Testing sequential indices...") + old_pos_seq = torch.arange(20, dtype=torch.int32, device="cuda") + new_pos_seq = torch.arange(200, 220, dtype=torch.int32, device="cuda") + + k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") + k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + + buffer_seq = torch.zeros((buffer_size, 1, 132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, old_pos_seq, buffer_seq) + copy_indexer_ks(buffer_seq, old_pos_seq, new_pos_seq) + + k_fp8_ext_seq, k_scale_ext_seq = extract_indexer_ks(buffer_seq, new_pos_seq) + + fp8_match_seq = torch.allclose(k_fp8_ext_seq.to(torch.float32), k_fp8_seq.to(torch.float32), atol=0, rtol=0) + scale_match_seq = torch.allclose(k_scale_ext_seq, k_scale_seq.squeeze(-1), atol=1e-6, rtol=1e-5) + + print(f" Sequential indices: FP8={fp8_match_seq}, Scale={scale_match_seq}") + assert fp8_match_seq and scale_match_seq + print("✓ Sequential test passed!") + print() + + # Test with single element + print("Testing single element...") + old_pos_single = torch.tensor([42], dtype=torch.int32, device="cuda") + new_pos_single = torch.tensor([424], dtype=torch.int32, device="cuda") + + k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") + k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) + k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) + k_fp8_single = ( + (k_bf16_single / k_abs_max_single).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + ) + + buffer_single = torch.zeros((buffer_size, 1, 
132), dtype=torch.uint8, device="cuda") + destindex_copy_indexer_ks(k_fp8_single, k_scale_single, old_pos_single, buffer_single) + copy_indexer_ks(buffer_single, old_pos_single, new_pos_single) + + k_fp8_ext_single, k_scale_ext_single = extract_indexer_ks(buffer_single, new_pos_single) + + fp8_match_single = torch.allclose( + k_fp8_ext_single.to(torch.float32), k_fp8_single.to(torch.float32), atol=0, rtol=0 + ) + scale_match_single = torch.allclose(k_scale_ext_single, k_scale_single.squeeze(-1), atol=1e-6, rtol=1e-5) + + print(f" Single element: FP8={fp8_match_single}, Scale={scale_match_single}") + assert fp8_match_single and scale_match_single + print("✓ Single element test passed!") + print() + + print("=" * 80) + print("All tests passed successfully! ✓") + print("=" * 80) + + +if __name__ == "__main__": + test_copy_indexer_ks() diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py index 46095bfb75..8faf3cdea4 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -21,55 +21,57 @@ def _fwd_kernel_destindex_copy_indexer_ks( ): """ Triton kernel to copy FP8 K values and their scales to an indexed output buffer. - + This kernel reads FP8 key values (128 dims) and their float32 scale values, then writes them to a compact buffer format where each entry contains: - Bytes 0-127: FP8 key values (128 bytes) - Bytes 128-131: Float32 scale (4 bytes) - + The destination location for each source element is specified by DestLoc. 
""" cur_index = tl.program_id(0) offs_d = tl.arange(0, BLOCK_DMODEL) - + # Load destination index for this thread dest_index = tl.load(DestLoc + cur_index).to(tl.int64) - + # Load K_fp8 (128 values) and K_scale (1 value) from source k_fp8_ptrs = K_fp8 + cur_index * stride_k_bs + stride_k_d * offs_d k_fp8 = tl.load(k_fp8_ptrs) - + k_scale = tl.load(K_scale + cur_index * stride_scale_bs) - + # Store K_fp8 to O_buffer[:, 0, :128] # Convert fp8 to uint8 through bitcast for storage in uint8 buffer o_k_ptrs = O_buffer + dest_index * stride_o_bs + stride_o_d * offs_d k_fp8_as_uint8 = k_fp8.to(tl.uint8, bitcast=True) tl.store(o_k_ptrs, k_fp8_as_uint8) - + # Store K_scale to O_buffer[:, 0, 128:132] (4 bytes for float32) # Convert float32 scale to 4 uint8 bytes using bitcast and bit manipulation o_scale_ptr = O_buffer + dest_index * stride_o_bs + BLOCK_DMODEL * stride_o_d scale_as_uint32 = k_scale.to(tl.float32, bitcast=True).to(tl.uint32, bitcast=True) - + # Store each byte of the float32 scale (little-endian) for i in range(4): byte_val = ((scale_as_uint32 >> (i * 8)) & 0xFF).to(tl.uint8) tl.store(o_scale_ptr + i * stride_o_d, byte_val) - + return @torch.no_grad() -def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor): +def destindex_copy_indexer_ks( + K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor +): """ Copy FP8-quantized key values and their scales to indexed locations in a buffer. - + This function is used in the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) mechanism to store compressed key representations in a memory buffer. Each key is stored with its FP8 representation (128 bytes) followed by its float32 scale (4 bytes), for a total of 132 bytes per key. 
- + Args: K_fp8: [q_seq_len, 128] torch.fp8_e4m3fn FP8-quantized key values @@ -84,7 +86,7 @@ def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLo Returns: None (modifies O_buffer in-place) - + Example: >>> k_fp8 = torch.randn(50, 128).to(torch.float8_e4m3fn).cuda() >>> k_scale = torch.randn(50, 1).cuda() @@ -95,14 +97,21 @@ def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLo """ seq_len = DestLoc.shape[0] head_dim = K_fp8.shape[1] - + assert head_dim == 128, f"Expected head_dim=128, got {head_dim}" - assert K_scale.shape[0] == seq_len + + # Handle cases where tensor lengths don't match (e.g., during prefix cache) + actual_seq_len = min(K_scale.shape[0], seq_len) + if actual_seq_len < seq_len: + K_fp8 = K_fp8[:actual_seq_len] + K_scale = K_scale[:actual_seq_len] + DestLoc = DestLoc[:actual_seq_len] + assert O_buffer.shape[2] == 132, f"Expected O_buffer last dim=132, got {O_buffer.shape[2]}" - - grid = (seq_len,) + + grid = (actual_seq_len,) num_warps = 1 - + _fwd_kernel_destindex_copy_indexer_ks[grid]( K_fp8, K_scale, @@ -125,151 +134,125 @@ def destindex_copy_indexer_ks(K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLo def test_destindex_copy_indexer_ks(): """Test the destindex_copy_indexer_ks kernel""" import torch.nn.functional as F - + print("=" * 80) print("Testing destindex_copy_indexer_ks") print("=" * 80) - + # Test parameters q_seq_len = 50 head_dim = 128 large_size = 1024 dtype = torch.bfloat16 fp8_type = torch.float8_e4m3fn - + # Create random destination indices dest_loc = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() actual_seq_len = len(dest_loc) - + # Create input tensors k_bf16 = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") - + # Quantize to FP8 k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8 = (k_bf16 / k_abs_max).clamp( - 
torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - + k_fp8 = (k_bf16 / k_abs_max).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + # Create output buffer (as uint8 to allow reinterpretation) o_buffer_uint8 = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - + # Run kernel destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer_uint8) - + # Extract results k_fp8_out = o_buffer_uint8[:, 0, :128].view(fp8_type) - + # Extract scale by reinterpreting 4 bytes as float32 scale_bytes = o_buffer_uint8[:, 0, 128:132].contiguous() k_scale_out = scale_bytes.view(-1, 4).view(torch.float32).squeeze(-1) - + # Verify results at destination locations k_fp8_extracted = k_fp8_out[dest_loc] k_scale_extracted = k_scale_out[dest_loc] - + # Check FP8 values match - fp8_match = torch.allclose( - k_fp8_extracted.to(torch.float32), - k_fp8.to(torch.float32), - atol=0, rtol=0 - ) - + fp8_match = torch.allclose(k_fp8_extracted.to(torch.float32), k_fp8.to(torch.float32), atol=0, rtol=0) + # Check scales match - scale_match = torch.allclose( - k_scale_extracted, - k_scale.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - + scale_match = torch.allclose(k_scale_extracted, k_scale.squeeze(-1), atol=1e-6, rtol=1e-5) + # Check dequantized values k_dequant_out = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) cosine_sim = F.cosine_similarity(k_dequant_out, k_bf16, dim=-1).mean() - + print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") print(f" FP8 values match: {fp8_match}") print(f" Scale values match: {scale_match}") print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") - + assert fp8_match, "FP8 values do not match!" assert scale_match, "Scale values do not match!" 
assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" - + print("✓ Basic test passed!") print() - + # Test edge cases print("Testing edge cases...") - + # Test with sequential indices dest_loc_seq = torch.arange(20, device="cuda", dtype=torch.int32) k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - + k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + o_buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, dest_loc_seq, o_buffer_seq) - + k_fp8_out_seq = o_buffer_seq[:20, 0, :128].view(fp8_type) scale_bytes_seq = o_buffer_seq[:20, 0, 128:132].contiguous() k_scale_out_seq = scale_bytes_seq.view(-1, 4).view(torch.float32).squeeze(-1) - - fp8_match_seq = torch.allclose( - k_fp8_out_seq.to(torch.float32), - k_fp8_seq.to(torch.float32), - atol=0, rtol=0 - ) - scale_match_seq = torch.allclose( - k_scale_out_seq, - k_scale_seq.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - + + fp8_match_seq = torch.allclose(k_fp8_out_seq.to(torch.float32), k_fp8_seq.to(torch.float32), atol=0, rtol=0) + scale_match_seq = torch.allclose(k_scale_out_seq, k_scale_seq.squeeze(-1), atol=1e-6, rtol=1e-5) + print(f" Sequential indices test: FP8={fp8_match_seq}, Scale={scale_match_seq}") assert fp8_match_seq and scale_match_seq print("✓ Edge case tests passed!") print() - + # Test with single element print("Testing single element...") dest_loc_single = torch.tensor([42], device="cuda", dtype=torch.int32) k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) 
k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - + k_fp8_single = ( + (k_bf16_single / k_abs_max_single).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) + ) + o_buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") destindex_copy_indexer_ks(k_fp8_single, k_scale_single, dest_loc_single, o_buffer_single) - + k_fp8_out_single = o_buffer_single[42:43, 0, :128].view(fp8_type) scale_bytes_single = o_buffer_single[42:43, 0, 128:132].contiguous() k_scale_out_single = scale_bytes_single.view(-1, 4).view(torch.float32).squeeze(-1) - + fp8_match_single = torch.allclose( - k_fp8_out_single.to(torch.float32), - k_fp8_single.to(torch.float32), - atol=0, rtol=0 + k_fp8_out_single.to(torch.float32), k_fp8_single.to(torch.float32), atol=0, rtol=0 ) - scale_match_single = torch.allclose( - k_scale_out_single, - k_scale_single.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - + scale_match_single = torch.allclose(k_scale_out_single, k_scale_single.squeeze(-1), atol=1e-6, rtol=1e-5) + print(f" Single element test: FP8={fp8_match_single}, Scale={scale_match_single}") assert fp8_match_single and scale_match_single print("✓ Single element test passed!") print() - + print("=" * 80) print("All tests passed successfully! ✓") print("=" * 80) if __name__ == "__main__": - test_destindex_copy_indexer_ks() \ No newline at end of file + test_destindex_copy_indexer_ks() From 17204e75c845a49677267e3c68bb273d30e1b0e4 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 2 Feb 2026 11:45:47 +0000 Subject: [PATCH 14/58] abstract NSA attention into backend framework Add NSA (Native Sparse Attention) backend abstraction following the existing MLA pattern. This enables future support for multiple NSA implementations (flashmla_sparse, fa3, tilelang, aiter). 
- Add attention framework from origin/main with NSA extensions - Create NsaFlashMlaSparseAttBackend with prefill/decode states - Extend AttControl with nsa_prefill/nsa_decode params - Add factory functions get_nsa_*_att_backend_class() - Refactor DeepSeek V3.2 to use NSA backend - Add missing envs_utils functions for compatibility --- .../common/basemodel/attention/__init__.py | 5 + .../common/basemodel/attention/base_att.py | 5 + .../basemodel/attention/create_utils.py | 47 +++++ .../basemodel/attention/nsa/__init__.py | 13 ++ .../attention/nsa/flashmla_sparse.py | 172 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 81 ++++++--- 6 files changed, 299 insertions(+), 24 deletions(-) create mode 100644 lightllm/common/basemodel/attention/nsa/__init__.py create mode 100644 lightllm/common/basemodel/attention/nsa/flashmla_sparse.py diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 80df545498..0eea52cc89 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -10,9 +10,14 @@ from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend +# NSA backend +from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend + from .create_utils import ( get_prefill_att_backend_class, get_decode_att_backend_class, get_mla_prefill_att_backend_class, get_mla_decode_att_backend_class, + get_nsa_prefill_att_backend_class, + get_nsa_decode_att_backend_class, ) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 859d97ca84..1286a46ec2 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -65,6 +65,11 @@ class AttControl: mla_prefill_dict: Dict = None mla_decode: bool = False mla_decode_dict: Dict = None + # nsa (native sparse attention) 专用传参项 + nsa_prefill: bool = False + 
nsa_prefill_dict: Dict = None + nsa_decode: bool = False + nsa_decode_dict: Dict = None @dataclass diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 19252cf13a..dd38028951 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -17,6 +17,9 @@ from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend +# NSA backend +from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend + logger = init_logger(__name__) # Backend class mappings by data type @@ -46,6 +49,14 @@ }, } +# NSA (Native Sparse Attention) backend mappings +nsa_data_type_to_backend = { + "None": { + "flashmla_sparse": NsaFlashMlaSparseAttBackend, + # Future backends: "fa3", "tilelang", "aiter" + }, +} + def _auto_select_backend( llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", "triton"] @@ -105,3 +116,39 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla return mla_data_type_to_backend[llm_dtype][backend_str] else: return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) + + +def get_nsa_prefill_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: + """Get NSA prefill attention backend class. + + Args: + backend_str: Backend name, currently only "flashmla_sparse" is supported. 
+ Future options: "fa3", "tilelang", "aiter" + + Returns: + NSA attention backend class + """ + # NSA currently only supports "None" dtype (no quantization) + llm_dtype = "None" + if backend_str not in nsa_data_type_to_backend[llm_dtype]: + logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") + backend_str = "flashmla_sparse" + return nsa_data_type_to_backend[llm_dtype][backend_str] + + +def get_nsa_decode_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: + """Get NSA decode attention backend class. + + Args: + backend_str: Backend name, currently only "flashmla_sparse" is supported. + Future options: "fa3", "tilelang", "aiter" + + Returns: + NSA attention backend class + """ + # NSA currently only supports "None" dtype (no quantization) + llm_dtype = "None" + if backend_str not in nsa_data_type_to_backend[llm_dtype]: + logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") + backend_str = "flashmla_sparse" + return nsa_data_type_to_backend[llm_dtype][backend_str] diff --git a/lightllm/common/basemodel/attention/nsa/__init__.py b/lightllm/common/basemodel/attention/nsa/__init__.py new file mode 100644 index 0000000000..11a1ebfdcd --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/__init__.py @@ -0,0 +1,13 @@ +"""NSA (Native Sparse Attention) backend implementations.""" + +from .flashmla_sparse import ( + NsaFlashMlaSparseAttBackend, + NsaFlashMlaSparsePrefillAttState, + NsaFlashMlaSparseDecodeAttState, +) + +__all__ = [ + "NsaFlashMlaSparseAttBackend", + "NsaFlashMlaSparsePrefillAttState", + "NsaFlashMlaSparseDecodeAttState", +] diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py new file mode 100644 index 0000000000..8e52499998 --- /dev/null +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -0,0 +1,172 @@ +"""NSA FlashMLA-sparse attention backend implementation. 
+ +This backend uses sgl_kernel's flash_mla_sparse_fwd for prefill +and flash_attn_with_kvcache for decode with sparse indices. +""" + +import dataclasses +import torch +from typing import Tuple, TYPE_CHECKING + +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_current_device_id + +if TYPE_CHECKING: + from lightllm.common.basemodel.infer_struct import InferStateInfo + + +class NsaFlashMlaSparseAttBackend(BaseAttBackend): + """NSA backend using FlashMLA sparse kernels from sgl_kernel.""" + + def __init__(self, model): + super().__init__(model=model) + + def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparsePrefillAttState": + return NsaFlashMlaSparsePrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparseDecodeAttState": + return NsaFlashMlaSparseDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class NsaFlashMlaSparsePrefillAttState(BasePrefillAttState): + """Prefill attention state for NSA using flash_mla_sparse_fwd.""" + + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + """Execute NSA prefill attention. 
+ + Args: + q: Query tensor [total_tokens, num_heads, head_dim] - already projected with k_b_proj + k: KV buffer tensor from memory manager + v: Not used for NSA (pass None) + att_control: Must have nsa_prefill=True and nsa_prefill_dict with: + - topk_indices: Sparse attention indices [total_tokens, topk] + - softmax_scale: Attention softmax scale + - kv_lora_rank: d_v dimension for MLA + + Returns: + Output tensor [total_tokens, num_heads, kv_lora_rank] + """ + assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" + assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" + + return self._nsa_prefill_att(q=q, kv=k, att_control=att_control) + + def _nsa_prefill_att( + self, + q: torch.Tensor, + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + from sgl_kernel.flash_mla import flash_mla_sparse_fwd + + nsa_dict = att_control.nsa_prefill_dict + topk_indices = nsa_dict["topk_indices"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + + # flash_mla_sparse_fwd expects indices with shape [total_tokens, 1, topk] + if topk_indices.ndim == 2: + topk_indices = topk_indices.unsqueeze(1) + + mla_out, _, _ = flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=topk_indices, + sm_scale=softmax_scale, + d_v=kv_lora_rank, + ) + return mla_out + + +@dataclasses.dataclass +class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState): + """Decode attention state for NSA using flash_attn_with_kvcache.""" + + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + """Execute NSA decode attention. 
+ + Args: + q: Tuple of (q_nope, q_rope) tensors + k: KV buffer tensor from memory manager + v: Not used for NSA (pass None) + att_control: Must have nsa_decode=True and nsa_decode_dict with: + - topk_indices: Page table for sparse attention [batch, topk] + - nsa_cache_seqlens: Cache sequence lengths for NSA + - nsa_cu_seqlens_k: Cumulative sequence lengths for NSA + - softmax_scale: Attention softmax scale + - kv_lora_rank: d_v dimension for MLA + - qk_rope_head_dim: Rope head dimension + + Returns: + Output tensor + """ + assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" + assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" + + return self._nsa_decode_att(q=q, kv=k, att_control=att_control) + + def _nsa_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + kv: torch.Tensor, + att_control: AttControl, + ) -> torch.Tensor: + from sgl_kernel.flash_attn import flash_attn_with_kvcache + + nsa_dict = att_control.nsa_decode_dict + topk_indices = nsa_dict["topk_indices"] + nsa_cache_seqlens = nsa_dict["nsa_cache_seqlens"] + nsa_cu_seqlens_k = nsa_dict["nsa_cu_seqlens_k"] + softmax_scale = nsa_dict["softmax_scale"] + kv_lora_rank = nsa_dict["kv_lora_rank"] + qk_rope_head_dim = nsa_dict["qk_rope_head_dim"] + + q_nope, q_rope = q + + # Extract k_rope and kv_nope from the KV buffer + k_rope = kv[:, :, -qk_rope_head_dim:].reshape(-1, 1, 1, qk_rope_head_dim) + kv_nope = kv[:, :, :-qk_rope_head_dim].reshape(-1, 1, 1, kv_lora_rank) + + o_tensor = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=topk_indices, + cache_seqlens=nsa_cache_seqlens, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=nsa_cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=softmax_scale, + causal=True, + ) + return o_tensor diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py 
b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index cf748bcdb6..1abb749a8e 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -2,8 +2,6 @@ from typing import override import torch -from sgl_kernel.flash_mla import flash_mla_sparse_fwd -from sgl_kernel.flash_attn import flash_attn_with_kvcache from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer @@ -12,6 +10,8 @@ from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.common.basemodel.attention.create_utils import get_nsa_prefill_att_backend_class class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): @@ -21,8 +21,19 @@ def __init__(self, layer_num, network_config, mode=[]): self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_, mode=mode) self.topk_indices = None + + # Initialize NSA attention backend (singleton, lazy initialization) + self._nsa_backend_class = get_nsa_prefill_att_backend_class() + self._nsa_backend = None return + def _get_nsa_backend(self): + """Get or create the NSA backend (lazy initialization).""" + if self._nsa_backend is None: + # NSA backend doesn't require model reference for basic operations + self._nsa_backend = self._nsa_backend_class(model=None) + return self._nsa_backend + @override def _get_qkv( self, @@ -80,16 +91,30 @@ def _nsa_context_attention_kernel( layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: - + # Model-specific q projection 
(uses layer weights) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) - mla_out, _, _ = flash_mla_sparse_fwd( + + # Use NSA backend for attention computation + att_control = AttControl( + nsa_prefill=True, + nsa_prefill_dict={ + "topk_indices": self.topk_indices, + "softmax_scale": self.softmax_scale, + "kv_lora_rank": self.kv_lora_rank, + }, + ) + + # Create prefill state and execute attention + nsa_backend = self._get_nsa_backend() + prefill_state = nsa_backend.create_att_prefill_state(infer_state) + prefill_state.init_state() + mla_out = prefill_state.prefill_att( q=q_all, - kv=infer_state.mem_manager.kv_buffer[self.layer_num_], - indices=self.topk_indices.unsqueeze(1), - sm_scale=self.softmax_scale, - d_v=self.kv_lora_rank, + k=infer_state.mem_manager.kv_buffer[self.layer_num_], + v=None, + att_control=att_control, ) return mla_out @@ -100,23 +125,31 @@ def _nsa_token_attention_kernel( layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ): + # Model-specific q projection (uses layer weights) q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) - kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - - o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, - v_cache=kv_nope, - qv=q_nope, - page_table=self.topk_indices, - cache_seqlens=infer_state.nsa_cache_seqlens, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.nsa_cu_seqlens_k, - max_seqlen_q=infer_state.max_q_seq_len, - softmax_scale=self.softmax_scale, - causal=True, + + # Use NSA backend for attention computation + att_control = 
AttControl( + nsa_decode=True, + nsa_decode_dict={ + "topk_indices": self.topk_indices, + "nsa_cache_seqlens": infer_state.nsa_cache_seqlens, + "nsa_cu_seqlens_k": infer_state.nsa_cu_seqlens_k, + "softmax_scale": self.softmax_scale, + "kv_lora_rank": self.kv_lora_rank, + "qk_rope_head_dim": self.qk_rope_head_dim, + }, + ) + + # Create decode state and execute attention + nsa_backend = self._get_nsa_backend() + decode_state = nsa_backend.create_att_decode_state(infer_state) + decode_state.init_state() + o_tensor = decode_state.decode_att( + q=(q_nope, q_rope), + k=infer_state.mem_manager.kv_buffer[self.layer_num_], + v=None, + att_control=att_control, ) return o_tensor From b52f3db4f7dd8a074aa93519b32fed361a222dba Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 01:43:51 +0000 Subject: [PATCH 15/58] rebase --- .../basemodel/attention/create_utils.py | 23 - .../attention/nsa/flashmla_sparse.py | 41 -- lightllm/common/basemodel/basemodel.py | 20 - lightllm/common/infer_utils.py | 63 +-- .../deepseek2_mem_manager.py | 4 +- .../kv_cache_mem_manager/mem_manager.py | 30 +- lightllm/models/deepseek3_2/infer_struct.py | 196 +++---- .../layer_infer/nsa_indexer_layer_inder.py | 6 +- .../layer_infer/transformer_layer_infer.py | 18 +- .../layer_weights/nsa_indexer_layer_weight.py | 15 +- lightllm/models/deepseek3_2/mem_manager.py | 43 +- lightllm/models/deepseek3_2/model.py | 4 +- .../triton_kernel/copy_indexer_ks.py | 232 -------- lightllm/server/api_cli.py | 2 +- lightllm/server/api_openai.py | 33 ++ lightllm/server/core/objs/sampling_params.py | 45 +- lightllm/server/function_call_parser.py | 503 +++++++++++++++++- .../tool_chat_template_deepseekv32.jinjia | 301 +++++++---- test/test_api/test_gsmk.py | 265 +++++++++ 19 files changed, 1169 insertions(+), 675 deletions(-) delete mode 100644 lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py create mode 100644 test/test_api/test_gsmk.py diff --git a/lightllm/common/basemodel/attention/create_utils.py 
b/lightllm/common/basemodel/attention/create_utils.py index dd38028951..e3bf81daed 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -16,8 +16,6 @@ from .flashinfer.fp8 import Fp8FlashInferAttBackend from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend - -# NSA backend from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend logger = init_logger(__name__) @@ -49,7 +47,6 @@ }, } -# NSA (Native Sparse Attention) backend mappings nsa_data_type_to_backend = { "None": { "flashmla_sparse": NsaFlashMlaSparseAttBackend, @@ -119,16 +116,6 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla def get_nsa_prefill_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: - """Get NSA prefill attention backend class. - - Args: - backend_str: Backend name, currently only "flashmla_sparse" is supported. - Future options: "fa3", "tilelang", "aiter" - - Returns: - NSA attention backend class - """ - # NSA currently only supports "None" dtype (no quantization) llm_dtype = "None" if backend_str not in nsa_data_type_to_backend[llm_dtype]: logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") @@ -137,16 +124,6 @@ def get_nsa_prefill_att_backend_class(backend_str: str = "flashmla_sparse") -> B def get_nsa_decode_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: - """Get NSA decode attention backend class. - - Args: - backend_str: Backend name, currently only "flashmla_sparse" is supported. 
- Future options: "fa3", "tilelang", "aiter" - - Returns: - NSA attention backend class - """ - # NSA currently only supports "None" dtype (no quantization) llm_dtype = "None" if backend_str not in nsa_data_type_to_backend[llm_dtype]: logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index 8e52499998..3eec98f055 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -1,9 +1,3 @@ -"""NSA FlashMLA-sparse attention backend implementation. - -This backend uses sgl_kernel's flash_mla_sparse_fwd for prefill -and flash_attn_with_kvcache for decode with sparse indices. -""" - import dataclasses import torch from typing import Tuple, TYPE_CHECKING @@ -16,8 +10,6 @@ class NsaFlashMlaSparseAttBackend(BaseAttBackend): - """NSA backend using FlashMLA sparse kernels from sgl_kernel.""" - def __init__(self, model): super().__init__(model=model) @@ -47,20 +39,6 @@ def prefill_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - """Execute NSA prefill attention. 
- - Args: - q: Query tensor [total_tokens, num_heads, head_dim] - already projected with k_b_proj - k: KV buffer tensor from memory manager - v: Not used for NSA (pass None) - att_control: Must have nsa_prefill=True and nsa_prefill_dict with: - - topk_indices: Sparse attention indices [total_tokens, topk] - - softmax_scale: Attention softmax scale - - kv_lora_rank: d_v dimension for MLA - - Returns: - Output tensor [total_tokens, num_heads, kv_lora_rank] - """ assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" @@ -79,7 +57,6 @@ def _nsa_prefill_att( softmax_scale = nsa_dict["softmax_scale"] kv_lora_rank = nsa_dict["kv_lora_rank"] - # flash_mla_sparse_fwd expects indices with shape [total_tokens, 1, topk] if topk_indices.ndim == 2: topk_indices = topk_indices.unsqueeze(1) @@ -95,7 +72,6 @@ def _nsa_prefill_att( @dataclasses.dataclass class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState): - """Decode attention state for NSA using flash_attn_with_kvcache.""" cu_seqlens_q: torch.Tensor = None cu_seqlens_k: torch.Tensor = None @@ -112,23 +88,6 @@ def decode_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - """Execute NSA decode attention. 
- - Args: - q: Tuple of (q_nope, q_rope) tensors - k: KV buffer tensor from memory manager - v: Not used for NSA (pass None) - att_control: Must have nsa_decode=True and nsa_decode_dict with: - - topk_indices: Page table for sparse attention [batch, topk] - - nsa_cache_seqlens: Cache sequence lengths for NSA - - nsa_cu_seqlens_k: Cumulative sequence lengths for NSA - - softmax_scale: Attention softmax scale - - kv_lora_rank: d_v dimension for MLA - - qk_rope_head_dim: Rope head dimension - - Returns: - Output tensor - """ assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c11b68c99c..5c1d2b8712 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -476,24 +476,6 @@ def _prefill( ) infer_state = self._create_inferstate(model_input) - - # Capture old indexer_ks positions before they are overwritten - # This is needed for DeepSeek v3.2 to copy cached tokens' indexer_ks - old_indexer_ks_positions = [] - for i in range(infer_state.b_req_idx.shape[0]): - req_idx = infer_state.b_req_idx[i].item() - ready_cache_len = infer_state.b_ready_cache_len[i].item() - - if ready_cache_len > 0: - # Capture old positions for cached tokens - old_pos = self.req_manager.req_to_token_indexs[ - req_idx, 0:ready_cache_len - ].clone() # Clone to avoid view issues - old_indexer_ks_positions.append(old_pos) - else: - # No cached tokens for this request - old_indexer_ks_positions.append(None) - init_req_to_token_indexes( req_to_token_indexs=self.req_manager.req_to_token_indexs, b_req_idx=infer_state.b_req_idx, @@ -502,8 +484,6 @@ def _prefill( b_start_loc=model_input.b_prefill_start_loc, alloc_mem_index=infer_state.mem_index, max_q_seq_len=infer_state.max_q_seq_len, - mem_manager=self.req_manager.mem_manager, - 
old_indexer_ks_positions=old_indexer_ks_positions, ) prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index ed3c0b73e4..26cf973be4 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -1,3 +1,4 @@ +import torch from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req_prefill @@ -9,10 +10,7 @@ def init_req_to_token_indexes( b_start_loc, alloc_mem_index, max_q_seq_len, - mem_manager=None, - old_indexer_ks_positions=None, ): - # Step 1: Copy KV cache for NEW tokens (existing logic) copy_kv_index_to_req_prefill( req_to_token_indexs=req_to_token_indexs, b_req_idx=b_req_idx, @@ -22,62 +20,3 @@ def init_req_to_token_indexes( memindex=alloc_mem_index, max_q_seq_len=max_q_seq_len, ) - - # Step 2: Copy indexer_ks for CACHED tokens (DeepSeek v3.2 specific) - # This ensures consistency between KV cache and indexer_ks buffers - # when prefix cache is hit - if ( - mem_manager is not None - and hasattr(mem_manager, "indexer_ks_mem_manager") - and old_indexer_ks_positions is not None - ): - - _copy_cached_indexer_ks_to_new_positions( - req_to_token_indexs=req_to_token_indexs, - b_req_idx=b_req_idx, - b_ready_cache_len=b_ready_cache_len, - mem_manager=mem_manager, - old_indexer_ks_positions=old_indexer_ks_positions, - ) - - -def _copy_cached_indexer_ks_to_new_positions( - req_to_token_indexs, - b_req_idx, - b_ready_cache_len, - mem_manager, - old_indexer_ks_positions, -): - """ - Copy cached tokens' indexer_ks from old positions to new positions. - - This function is called after copy_kv_index_to_req_prefill() has updated - req_to_token_indexs to point to new contiguous positions. We need to copy - indexer_ks data to match the KV cache layout. 
- - For each layer and each request with cached tokens: - - Copy indexer_ks data from old positions to new positions - - This ensures consistency when using extract_indexer_ks later - """ - from lightllm.models.deepseek3_2.triton_kernel.copy_indexer_ks import copy_indexer_ks - - # Get number of layers from indexer_ks_mem_manager - num_layers = len(mem_manager.indexer_ks_mem_manager.kv_buffer) - indexer_buffer = mem_manager.indexer_ks_mem_manager.kv_buffer - - for layer_idx in range(num_layers): - for i in range(b_req_idx.shape[0]): - req_idx = b_req_idx[i].item() - ready_cache_len = b_ready_cache_len[i].item() - old_positions = old_indexer_ks_positions[i] - - if ready_cache_len > 0 and old_positions is not None: - # New positions after copy_kv_index_to_req_prefill - new_positions = req_to_token_indexs[req_idx, 0:ready_cache_len] - - # Copy indexer_ks: old_positions -> new_positions - copy_indexer_ks( - buffer=indexer_buffer[layer_idx], - src_loc=old_positions, - dest_loc=new_positions, - ) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index ad54b39353..3d93e1b070 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -15,8 +15,8 @@ class Deepseek2MemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): """ diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py 
b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 7d5e2af046..1203cbdec7 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -26,9 +26,7 @@ class MemoryManager: - def __init__( - self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False - ): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): self.size = size self.head_num = head_num self.head_dim = head_dim @@ -50,16 +48,15 @@ def __init__( self.can_use_mem_size = self.size - if not is_sub_mem_manager: - # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name + # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 + from lightllm.utils.envs_utils import get_unique_server_name - rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) + rank_in_node = get_current_rank_in_node() + self.shared_can_use_token_num = SharedInt( + f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" + ) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) self._init_buffers( self.size, dtype, @@ -95,17 +92,6 @@ def profile_size(self, mem_fraction): available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) cell_size = self.get_cell_size() self.size = int(available_memory * 1024 ** 3 / cell_size) - - # Ensure size is at least a minimum positive value to avoid torch.arange errors - MIN_SIZE = 1024 # Minimum 1024 tokens - if self.size < MIN_SIZE: - logger.warning( - f"Insufficient memory for KV cache. Available: {available_memory:.2f} GB, " - f"but calculated size is {self.size} tokens. Using minimum size {MIN_SIZE} tokens instead. 
" - f"Consider reducing model size, using fewer GPUs, or increasing mem_fraction." - ) - self.size = MIN_SIZE - if world_size > 1: tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}") dist.all_reduce(tensor, op=dist.ReduceOp.MIN) diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index 2f8aa75629..e0cca499bd 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -1,10 +1,10 @@ import torch import weakref -from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager -class Deepseek3_2FlashAttentionStateInfo(Deepseek2FlashAttentionStateInfo): +class Deepseek3_2FlashAttentionStateInfo(Deepseek2InferStateInfo): _shared_nsa_buffers = None def __init__(self): @@ -21,7 +21,6 @@ def __init__(self): def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): """Get or create pre-allocated buffers for CUDA graph execution""" if cls._shared_nsa_buffers is None: - # Pre-allocate buffers for max possible sizes max_total_q_tokens = graph_max_batch_size * max_seq_len max_total_tokens = graph_max_batch_size * max_seq_len @@ -35,7 +34,7 @@ def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): "nsa_cache_seqlens": torch.empty(graph_max_batch_size, dtype=torch.int32, device="cuda"), "nsa_cu_seqlens_k": torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device="cuda"), }, - { # Second buffer for microbatch overlap if needed + { "ks": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), "ke": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), "lengths": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), @@ -47,155 +46,124 @@ def get_nsa_buffers(cls, graph_max_batch_size: int, 
max_seq_len: int): ] return cls._shared_nsa_buffers - def init_some_extra_state(self, model, input_ids: torch.Tensor): - super().init_some_extra_state(model, input_ids) + def _check_use_cuda_graph_buffers(self): + if hasattr(self, "_model_ref"): + model = self._model_ref() + if ( + model is not None + and hasattr(model, "graph_max_batch_size") + and hasattr(model, "graph_max_len_in_batch") + and self.batch_size <= model.graph_max_batch_size + and self.max_len_in_batch <= model.graph_max_len_in_batch + ): + return True + return False + + def init_some_extra_state(self, model): + super().init_some_extra_state(model) - # Store weak reference to model for accessing graph parameters self._model_ref = weakref.ref(model) assert isinstance(self.mem_manager, Deepseek3_2MemoryManager) - self.indexer_ks_mem_manager = self.mem_manager.indexer_ks_mem_manager + self.indexer_ks_buffer = self.mem_manager.indexer_ks_buffer - # Ensure b_ready_cache_len is set for both prefill and decode modes if self.is_prefill: - # b_ready_cache_len is already set in basemodel.py for prefill pass else: - # In decode mode, b_ready_cache_len is set by the router/scheduler - # based on actual prefix cache hits. 
If it's None (no prefix cache enabled), - # it should be 0, not computed from b_seq_len - b_q_seq_len if self.b_ready_cache_len is None: self.b_ready_cache_len = torch.zeros_like(self.b_seq_len) - # Check if we can use CUDA graph based on batch size and max_len constraints - use_cuda_graph_buffers = False - if ( - hasattr(model, "graph_max_batch_size") - and hasattr(model, "graph_max_len_in_batch") - and self.batch_size <= model.graph_max_batch_size - and self.max_len_in_batch <= model.graph_max_len_in_batch - ): - use_cuda_graph_buffers = True + use_cuda_graph_buffers = self._check_use_cuda_graph_buffers() - # Setup nsa_cache_seqlens and nsa_cu_seqlens_k with pre-allocated buffers if using CUDA graph if use_cuda_graph_buffers: buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) buffer = buffers[self.microbatch_index] - - # Use views into pre-allocated buffers self.nsa_cache_seqlens = buffer["nsa_cache_seqlens"][: self.batch_size] self.nsa_cu_seqlens_k = buffer["nsa_cu_seqlens_k"][: self.batch_size + 1] else: - # Create new tensors dynamically self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device="cuda") self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device="cuda") - # Calculate actual values self.nsa_cache_seqlens.copy_(self.b_att_seq_len.clamp(max=self.index_topk)) assert self.nsa_cache_seqlens.dtype == torch.int32 - # Compute cumulative sum with padding torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) self.nsa_cu_seqlens_k[0] = 0 - # Pre-compute NSA indexer indexing structures self._init_nsa_indexing_structures() def _init_nsa_indexing_structures(self): - """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer""" - req_all_mem_index_list = [] - ks_list = [] - ke_list = [] - lengths_list = [] - offset = 0 - num_seq_len = self.b_req_idx.shape[0] - max_seq_len = self.b_seq_len.max().item() - - # Calculate total sizes 
needed - total_q_len = sum(self.b_q_seq_len[i].item() for i in range(num_seq_len)) - total_seq_len = sum(self.b_seq_len[i].item() for i in range(num_seq_len)) - - # Check if we should use CUDA graph buffers - use_cuda_graph_buffers = False - if hasattr(self, "_model_ref"): - model = self._model_ref() - if ( - model is not None - and hasattr(model, "graph_max_batch_size") - and hasattr(model, "graph_max_len_in_batch") - and self.batch_size <= model.graph_max_batch_size - and self.max_len_in_batch <= model.graph_max_len_in_batch - ): - use_cuda_graph_buffers = True + """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer. + + Fully vectorized: eliminates per-request .item() CPU-GPU syncs. + """ + b_seq_len = self.b_seq_len + b_q_seq_len = self.b_q_seq_len + b_req_idx = self.b_req_idx + num_seq = b_req_idx.shape[0] + device = b_seq_len.device + + # Only 3 scalar syncs needed (for tensor shapes) + max_seq_len = b_seq_len.max().item() + total_q_len = b_q_seq_len.sum().item() + total_seq_len = b_seq_len.sum().item() + + # --- page_table_size_1 and req_all_mem_index (vectorized gather) --- + all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] + seq_range = torch.arange(max_seq_len, device=device) + valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) + + # page_table_size_1: [batch, max_seq_len] zero-padded memory indices + page_table = torch.zeros((num_seq, max_seq_len), dtype=torch.int, device=device) + page_table[valid_mask] = all_rows[valid_mask].int() + + # req_all_mem_index: flattened valid memory indices across all requests + req_all_mem_index = all_rows[valid_mask] + + # --- ks, ke, lengths (vectorized computation) --- + # Cumulative seq_len offsets: [0, seq_len[0], seq_len[0]+seq_len[1], ...] 
+ cum_seq = torch.cumsum(b_seq_len, dim=0) + seq_offsets = torch.zeros_like(cum_seq) + seq_offsets[1:] = cum_seq[:-1] + + # Expand per-request values to per-token using repeat_interleave + req_indices = torch.repeat_interleave(torch.arange(num_seq, device=device), b_q_seq_len) + + # Token position within each request's q_seq + cum_q = torch.cumsum(b_q_seq_len, dim=0) + q_offsets = torch.zeros_like(cum_q) + q_offsets[1:] = cum_q[:-1] + token_in_req = torch.arange(total_q_len, device=device) - q_offsets[req_indices] + + # ks[t] = seq_offset of request owning token t + # ke[t] = seq_offset + position_in_q + 1 + # lengths[t] = seq_len - q_seq_len + position_in_q + 1 + ks = seq_offsets[req_indices].int() + ke = (seq_offsets[req_indices] + token_in_req + 1).int() + lengths = (b_seq_len[req_indices] - b_q_seq_len[req_indices] + token_in_req + 1).int() + + # --- Assign results (CUDA graph buffer or new tensors) --- + use_cuda_graph_buffers = self._check_use_cuda_graph_buffers() if use_cuda_graph_buffers: - # Use pre-allocated buffers for CUDA graph model = self._model_ref() buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) buffer = buffers[self.microbatch_index] - # Use views into pre-allocated buffers self.ks = buffer["ks"][:total_q_len] self.ke = buffer["ke"][:total_q_len] self.lengths = buffer["lengths"][:total_q_len] - self.page_table_size_1 = buffer["page_table_size_1"][:num_seq_len, :max_seq_len] + self.page_table_size_1 = buffer["page_table_size_1"][:num_seq, :max_seq_len] self.req_all_mem_index = buffer["req_all_mem_index"][:total_seq_len] - # Zero out page_table_size_1 before filling - self.page_table_size_1.zero_() - - # Compute and copy values into the pre-allocated buffer views - ks_offset = 0 - ke_offset = 0 - lengths_offset = 0 - req_offset = 0 - seq_offset = 0 - - for i in range(num_seq_len): - seq_len = self.b_seq_len[i].item() - q_seq_len = self.b_q_seq_len[i].item() - req_idx = self.b_req_idx[i].item() - mem_index = 
self.req_manager.req_to_token_indexs[req_idx, :seq_len] - - # Copy req_all_mem_index - self.req_all_mem_index[req_offset : req_offset + seq_len] = mem_index - - # Fill page_table_size_1 - self.page_table_size_1[i, :seq_len] = mem_index - - # Fill ks, ke, lengths - self.ks[ks_offset : ks_offset + q_seq_len].fill_(seq_offset) - self.ke[ke_offset : ke_offset + q_seq_len] = torch.arange( - seq_offset + 1, seq_offset + q_seq_len + 1, dtype=torch.int, device="cuda" - ) - self.lengths[lengths_offset : lengths_offset + q_seq_len] = torch.arange( - seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device="cuda" - ) - - ks_offset += q_seq_len - ke_offset += q_seq_len - lengths_offset += q_seq_len - req_offset += seq_len - seq_offset += seq_len + self.ks.copy_(ks) + self.ke.copy_(ke) + self.lengths.copy_(lengths) + self.page_table_size_1.copy_(page_table) + self.req_all_mem_index.copy_(req_all_mem_index) else: - # Original dynamic allocation for non-CUDA graph mode - self.page_table_size_1 = torch.zeros((num_seq_len, max_seq_len), dtype=torch.int, device="cuda") - - for i in range(num_seq_len): - seq_len = self.b_seq_len[i].item() - q_seq_len = self.b_q_seq_len[i].item() - req_idx = self.b_req_idx[i].item() - mem_index = self.req_manager.req_to_token_indexs[req_idx, :seq_len] - req_all_mem_index_list.append(mem_index) - self.page_table_size_1[i, :seq_len] = mem_index - ks = torch.zeros(q_seq_len, dtype=torch.int, device="cuda") + offset - ke = torch.arange(q_seq_len, dtype=torch.int, device="cuda") + offset + 1 - ks_list.append(ks) - ke_list.append(ke) - lengths_list.append(torch.arange(seq_len - q_seq_len + 1, seq_len + 1, dtype=torch.int, device="cuda")) - offset += seq_len - - self.req_all_mem_index = torch.cat(req_all_mem_index_list, dim=0) - self.ks = torch.cat(ks_list, dim=0) - self.ke = torch.cat(ke_list, dim=0) - self.lengths = torch.cat(lengths_list, dim=0) + self.ks = ks + self.ke = ke + self.lengths = lengths + self.page_table_size_1 = page_table + 
self.req_all_mem_index = req_all_mem_index diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 2f4421e742..3855bf590f 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -8,10 +8,8 @@ from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks -from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -84,7 +82,7 @@ def get_indices( k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) destindex_copy_indexer_ks( - k_fp8, k_scale, infer_state.mem_index, infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_] + k_fp8, k_scale, infer_state.mem_index, infer_state.indexer_ks_buffer.kv_buffer[self.layer_idx_] ) weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale @@ -97,7 +95,7 @@ def get_indices( # Use efficient Triton kernel to extract FP8 keys and scales from buffer k_fp8_, k_scale_ = extract_indexer_ks( - infer_state.indexer_ks_mem_manager.kv_buffer[self.layer_idx_], infer_state.req_all_mem_index + infer_state.indexer_ks_buffer.kv_buffer[self.layer_idx_], infer_state.req_all_mem_index ) # Get actual sequence length from q (which comes from q_lora) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py 
b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 1abb749a8e..bc8bb9c6ba 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,4 +1,3 @@ -from functools import partial from typing import override import torch @@ -8,7 +7,7 @@ from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.common.basemodel.attention.create_utils import get_nsa_prefill_att_backend_class @@ -73,17 +72,7 @@ def _get_qkv( return q, cache_kv @override - def _bind_attention(self): - if "triton_fp8kv" in self.mode: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self) - else: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - - self._context_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_context_attention_kernel, self) - self._token_attention_kernel = partial(Deepseek3_2TransformerLayerInfer._nsa_token_attention_kernel, self) - pass - - def _nsa_context_attention_kernel( + def _context_attention_kernel( self, q: torch.Tensor, kv, @@ -118,7 +107,8 @@ def _nsa_context_attention_kernel( ) return mla_out - def _nsa_token_attention_kernel( + @override + def _token_attention_kernel( self, q, infer_state: Deepseek3_2FlashAttentionStateInfo, diff --git 
a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py index 47e0bfdac5..9ccfbe97ed 100644 --- a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -3,7 +3,7 @@ import torch from lightllm.common.basemodel.layer_weights.transformer_layer_weight import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, NormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, LayerNormWeight class NSAIndexerWeight(TransformerLayerWeight): @@ -14,7 +14,7 @@ def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): @override def _init_weight(self): prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" - + self.wq_b_proj_ = ROWMMWeight( weight_name=f"{prefix}.wq_b.weight", data_type=self.data_type_, @@ -33,15 +33,16 @@ def _init_weight(self): tp_rank=0, tp_world_size=1, ) - self.k_norm_ = NormWeight( - f"{prefix}.k_norm.weight", - torch.float32, - bias_name=f"{prefix}.k_norm.bias" + self.k_norm_ = LayerNormWeight( + dim=self.network_config_["index_head_dim"], + weight_name=f"{prefix}.k_norm.weight", + data_type=torch.float32, + bias_name=f"{prefix}.k_norm.bias", ) self.weights_proj_ = ROWMMWeight( weight_name=f"{prefix}.weights_proj.weight", data_type=self.data_type_, - quant_cfg=None, + quant_cfg=None, layer_num=self.layer_num_, name="weights_proj", tp_rank=0, diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index a70c762731..8017a84adc 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -1,22 +1,41 @@ -from typing import List from typing_extensions import override import torch -from lightllm.common.mem_manager import MemoryManager -from lightllm.common.deepseek2_mem_manager import 
Deepseek2MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask -from lightllm.distributed.pynccl import PyNcclCommunicator +from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager +from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager + + +class IndexerKSBuffer: + """Lightweight buffer holder for NSA indexer keys+scales. + + Shares token indices with the parent MemoryManager — does NOT have its + own allocator. Only stores the per-layer kv_buffer tensor. + """ + + def __init__(self, size: int, head_num: int, head_dim: int, layer_num: int, dtype=torch.uint8): + self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") + class Deepseek3_2MemoryManager(Deepseek2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9 ,is_sub_mem_manager=False): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_sub_mem_manager) - self.indexer_ks_mem_manager = Deepseek2MemoryManager(self.size, torch.uint8, 1, 132, layer_num, is_sub_mem_manager=True) - return + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, layer_num) @override def get_cell_size(self): return super().get_cell_size() + 132 - + + @override + def _free_buffers(self): + super()._free_buffers() + self.indexer_ks_buffer = None + + @override + def resize_mem(self, new_size): + super().resize_mem(new_size) + self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, self.layer_num) + + class Deepseek3_2FP8KVMemoryManager(Deepseek3_2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, is_sub_mem_manager=False): - super().__init__(size, torch.uint8, 
head_num, head_dim + 2, layer_num, always_copy, mem_fraction, is_sub_mem_manager) \ No newline at end of file + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 8f1ba85cf2..d25cbd3783 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -5,6 +5,8 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager + + @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # weight class @@ -44,4 +46,4 @@ def _init_mem_manager(self): layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, mem_fraction=self.mem_fraction, ) - return \ No newline at end of file + return diff --git a/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py deleted file mode 100644 index 93cf463eb0..0000000000 --- a/lightllm/models/deepseek3_2/triton_kernel/copy_indexer_ks.py +++ /dev/null @@ -1,232 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_copy_indexer_ks( - buffer, # [large_size, 1, 132] uint8 - src_loc, # [copy_len] int32/int64 - source indices - dest_loc, # [copy_len] int32/int64 - destination indices - stride_bs, - stride_h, - stride_d, - BLOCK_KV: tl.constexpr, # = 128 (FP8 data) - BLOCK_SCALE: tl.constexpr, # = 4 (scale data) -): - """ - Triton kernel to copy indexer_ks data from source locations to destination locations. 
- - This kernel copies 132-byte indexer_ks entries (128 bytes FP8 key + 4 bytes float32 scale) - from source positions to destination positions within the same buffer. - - Args: - buffer: Shared buffer containing indexer_ks data [large_size, 1, 132] uint8 - src_loc: Source indices to copy from [copy_len] - dest_loc: Destination indices to copy to [copy_len] - stride_bs, stride_h, stride_d: Strides for the buffer - BLOCK_KV: Size of FP8 key data (128 bytes) - BLOCK_SCALE: Size of scale data (4 bytes) - """ - cur_index = tl.program_id(0) - offs_kv = tl.arange(0, BLOCK_KV) - offs_scale = tl.arange(0, BLOCK_SCALE) - - # Load source and destination indices - src_index = tl.load(src_loc + cur_index).to(tl.int64) - dest_index = tl.load(dest_loc + cur_index).to(tl.int64) - - # Copy FP8 key data (128 bytes) - src_kv_ptrs = buffer + src_index * stride_bs + stride_d * offs_kv - dest_kv_ptrs = buffer + dest_index * stride_bs + stride_d * offs_kv - kv_data = tl.load(src_kv_ptrs) - tl.store(dest_kv_ptrs, kv_data) - - # Copy scale data (4 bytes at offset 128) - src_scale_base = buffer + src_index * stride_bs + BLOCK_KV * stride_d - dest_scale_base = buffer + dest_index * stride_bs + BLOCK_KV * stride_d - scale_data = tl.load(src_scale_base + offs_scale * stride_d) - tl.store(dest_scale_base + offs_scale * stride_d, scale_data) - - return - - -@torch.no_grad() -def copy_indexer_ks( - buffer: torch.Tensor, - src_loc: torch.Tensor, - dest_loc: torch.Tensor, -): - """ - Copy indexer_ks data from source positions to destination positions. - - This function is used to copy cached tokens' indexer_ks data to new locations - after prefix cache matching. It ensures that the indexer_ks buffer stays - consistent with the KV cache buffer. 
- - Args: - buffer: [large_size, 1, 132] torch.uint8 - Buffer containing indexer_ks data (same buffer for src and dest) - src_loc: [copy_len] torch.int32 or torch.int64 - Source indices in buffer (old positions) - dest_loc: [copy_len] torch.int32 or torch.int64 - Destination indices in buffer (new positions) - - Returns: - None (modifies buffer in-place) - - Example: - >>> buffer = torch.zeros((1024, 1, 132), dtype=torch.uint8).cuda() - >>> old_pos = torch.tensor([100, 101, 102], dtype=torch.int32).cuda() - >>> new_pos = torch.tensor([200, 201, 202], dtype=torch.int32).cuda() - >>> copy_indexer_ks(buffer, old_pos, new_pos) - # Data from positions [100, 101, 102] is now copied to [200, 201, 202] - """ - copy_len = src_loc.shape[0] - block_kv = 128 # FP8 key data size - block_scale = 4 # Float32 scale size - - assert ( - src_loc.shape[0] == dest_loc.shape[0] - ), f"src_loc and dest_loc must have same length: {src_loc.shape[0]} != {dest_loc.shape[0]}" - assert ( - buffer.shape[2] == block_kv + block_scale - ), f"Expected buffer last dim={block_kv + block_scale}, got {buffer.shape[2]}" - assert buffer.dtype == torch.uint8, f"Expected buffer dtype=uint8, got {buffer.dtype}" - - grid = (copy_len,) - num_warps = 1 - - _fwd_kernel_copy_indexer_ks[grid]( - buffer, - src_loc, - dest_loc, - buffer.stride(0), - buffer.stride(1), - buffer.stride(2), - BLOCK_KV=block_kv, - BLOCK_SCALE=block_scale, - num_warps=num_warps, - num_stages=1, - ) - - return - - -def test_copy_indexer_ks(): - """Test the copy_indexer_ks kernel""" - import torch.nn.functional as F - from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks - from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks - - print("=" * 80) - print("Testing copy_indexer_ks") - print("=" * 80) - - # Test parameters - cached_len = 20 - buffer_size = 1024 - head_dim = 128 - dtype = torch.bfloat16 - fp8_type = torch.float8_e4m3fn - - # Create 
indexer_ks data - k_bf16 = torch.randn((cached_len, head_dim), dtype=dtype, device="cuda") - - # Quantize to FP8 - k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8 = (k_bf16 / k_abs_max).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - - # Write to old positions - old_positions = torch.arange(100, 100 + cached_len, dtype=torch.int32, device="cuda") - buffer = torch.zeros((buffer_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8, k_scale, old_positions, buffer) - - # Copy to new positions - new_positions = torch.arange(200, 200 + cached_len, dtype=torch.int32, device="cuda") - copy_indexer_ks(buffer, old_positions, new_positions) - - # Verify data at new positions matches original - k_fp8_extracted, k_scale_extracted = extract_indexer_ks(buffer, new_positions) - - fp8_match = torch.allclose(k_fp8_extracted.to(torch.float32), k_fp8.to(torch.float32), atol=0, rtol=0) - - scale_match = torch.allclose(k_scale_extracted, k_scale.squeeze(-1), atol=1e-6, rtol=1e-5) - - # Check dequantized values - k_dequant_extracted = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) - cosine_sim = F.cosine_similarity(k_dequant_extracted, k_bf16, dim=-1).mean() - - print(f"Cached tokens: {cached_len}, Head dim: {head_dim}") - print(f" FP8 values match: {fp8_match}") - print(f" Scale values match: {scale_match}") - print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") - - assert fp8_match, "FP8 values do not match!" - assert scale_match, "Scale values do not match!" 
- assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" - - print("✓ Basic test passed!") - print() - - # Test with sequential indices - print("Testing sequential indices...") - old_pos_seq = torch.arange(20, dtype=torch.int32, device="cuda") - new_pos_seq = torch.arange(200, 220, dtype=torch.int32, device="cuda") - - k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") - k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - - buffer_seq = torch.zeros((buffer_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, old_pos_seq, buffer_seq) - copy_indexer_ks(buffer_seq, old_pos_seq, new_pos_seq) - - k_fp8_ext_seq, k_scale_ext_seq = extract_indexer_ks(buffer_seq, new_pos_seq) - - fp8_match_seq = torch.allclose(k_fp8_ext_seq.to(torch.float32), k_fp8_seq.to(torch.float32), atol=0, rtol=0) - scale_match_seq = torch.allclose(k_scale_ext_seq, k_scale_seq.squeeze(-1), atol=1e-6, rtol=1e-5) - - print(f" Sequential indices: FP8={fp8_match_seq}, Scale={scale_match_seq}") - assert fp8_match_seq and scale_match_seq - print("✓ Sequential test passed!") - print() - - # Test with single element - print("Testing single element...") - old_pos_single = torch.tensor([42], dtype=torch.int32, device="cuda") - new_pos_single = torch.tensor([424], dtype=torch.int32, device="cuda") - - k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") - k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_single = ( - (k_bf16_single / k_abs_max_single).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - ) - - buffer_single = torch.zeros((buffer_size, 1, 
132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_single, k_scale_single, old_pos_single, buffer_single) - copy_indexer_ks(buffer_single, old_pos_single, new_pos_single) - - k_fp8_ext_single, k_scale_ext_single = extract_indexer_ks(buffer_single, new_pos_single) - - fp8_match_single = torch.allclose( - k_fp8_ext_single.to(torch.float32), k_fp8_single.to(torch.float32), atol=0, rtol=0 - ) - scale_match_single = torch.allclose(k_scale_ext_single, k_scale_single.squeeze(-1), atol=1e-6, rtol=1e-5) - - print(f" Single element: FP8={fp8_match_single}, Scale={scale_match_single}") - assert fp8_match_single and scale_match_single - print("✓ Single element test passed!") - print() - - print("=" * 80) - print("All tests passed successfully! ✓") - print("=" * 80) - - -if __name__ == "__main__": - test_copy_indexer_ks() diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 96126744af..4b92298b6b 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -128,7 +128,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--tool_call_parser", type=str, - choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "glm47", "kimi_k2"], + choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "deepseekv32", "glm47", "kimi_k2"], default=None, help="tool call parser type", ) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 11e24612b0..b65f7e9cc5 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -464,6 +464,39 @@ async def stream_results() -> AsyncGenerator[bytes, None]: yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") # Additional usage chunk + # Finalize any pending tool calls (e.g., DSML format last invoke) + if request.tool_choice != "none" and request.tools and parser_dict: + for _idx, _parser in parser_dict.items(): + _, finalize_calls = 
_parser.finalize_stream() + history_tool_calls_cnt = _get_history_tool_calls_cnt(request) + for call_item in finalize_calls: + if call_item.name: + tool_call_id = _process_tool_call_id(tool_parser, call_item, history_tool_calls_cnt) + function_name = call_item.name + else: + tool_call_id = None + function_name = None + tool_call = ToolCall( + id=tool_call_id, + index=getattr(call_item, "tool_index", None), + function=FunctionResponse( + name=function_name, + arguments=call_item.parameters, + ), + ) + choice_data = ChatCompletionStreamResponseChoice( + index=0, + delta=DeltaMessage(role="assistant", tool_calls=[tool_call]), + finish_reason="tool_calls", + ) + chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + if request.stream_options and request.stream_options.include_usage: usage = UsageInfo( prompt_tokens=prompt_tokens, diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index b9ad314dd9..f2c5177e41 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -332,27 +332,34 @@ class SamplingParams(ctypes.Structure): _top_p: float = 1.0 _top_k: int = -1 # -1 is for all + @staticmethod + def _get(kwargs, key, default): + """Get value from kwargs, falling back to default when value is None or missing.""" + val = kwargs.get(key) + return val if val is not None else default + def init(self, tokenizer, **kwargs): super().__init__() - self.best_of = kwargs.get("best_of", 1) - self.n = kwargs.get("n", self.best_of) - self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) - self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) - self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) - self.repetition_penalty = kwargs.get("repetition_penalty", 
SamplingParams._repetition_penalty) - self.temperature = kwargs.get("temperature", SamplingParams._temperature) - self.top_p = kwargs.get("top_p", SamplingParams._top_p) - self.top_k = kwargs.get("top_k", SamplingParams._top_k) - self.ignore_eos = kwargs.get("ignore_eos", False) - self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) - self.max_new_tokens = kwargs.get("max_new_tokens", 16) - self.min_new_tokens = kwargs.get("min_new_tokens", 1) - self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) - self.group_request_id = kwargs.get("group_request_id", -1) - self.suggested_dp_index = kwargs.get("suggested_dp_index", -1) - - self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) - self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False) + _get = SamplingParams._get + self.best_of = _get(kwargs, "best_of", 1) + self.n = _get(kwargs, "n", self.best_of) + self.do_sample = _get(kwargs, "do_sample", SamplingParams._do_sample) + self.presence_penalty = _get(kwargs, "presence_penalty", SamplingParams._presence_penalty) + self.frequency_penalty = _get(kwargs, "frequency_penalty", SamplingParams._frequency_penalty) + self.repetition_penalty = _get(kwargs, "repetition_penalty", SamplingParams._repetition_penalty) + self.temperature = _get(kwargs, "temperature", SamplingParams._temperature) + self.top_p = _get(kwargs, "top_p", SamplingParams._top_p) + self.top_k = _get(kwargs, "top_k", SamplingParams._top_k) + self.ignore_eos = _get(kwargs, "ignore_eos", False) + self.image_max_patch_num = _get(kwargs, "image_max_patch_num", -1) + self.max_new_tokens = _get(kwargs, "max_new_tokens", 16) + self.min_new_tokens = _get(kwargs, "min_new_tokens", 1) + self.input_penalty = _get(kwargs, "input_penalty", DEFAULT_INPUT_PENALTY) + self.group_request_id = _get(kwargs, "group_request_id", -1) + self.suggested_dp_index = _get(kwargs, "suggested_dp_index", -1) + + self.skip_special_tokens = _get(kwargs, 
"skip_special_tokens", SKIP_SPECIAL_TOKENS) + self.disable_prompt_cache = _get(kwargs, "disable_prompt_cache", False) self.add_special_tokens = kwargs.get("add_special_tokens", True) self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 9214715b1d..c3faf21e78 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -29,7 +29,15 @@ logger = logging.getLogger(__name__) -TOOLS_TAG_LIST = ["<|plugin|>", "", "<|python_tag|>", "[TOOL_CALLS]", "<|tool▁calls▁begin|>"] +TOOLS_TAG_LIST = [ + "<|plugin|>", + "", + "<|python_tag|>", + "[TOOL_CALLS]", + "<|tool▁calls▁begin|>", + "<|DSML|function_calls>", +] class ToolCallItem(BaseModel): @@ -1443,6 +1451,482 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami return StreamingParseResult(normal_text="", calls=calls) +class DeepSeekV32Detector(BaseFormatDetector): + """ + Detector for DeepSeek V3.2 model function call format (DSML). + + DeepSeek V3.2 uses a new DSML (DeepSeek Markup Language) format for tool calls, + which is XML-like rather than JSON-based. 
+ + Format Structure: + ``` + <|DSML|function_calls> + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">杭州 + <|DSML|parameter name="date" string="true">2024-01-16 + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">北京 + <|DSML|parameter name="date" string="true">2024-01-16 + ``` + + Key Components: + - Tool Calls Section: Starts with `<|DSML|function_calls>` + - Individual Invoke: `<|DSML|invoke name="function_name">` + - Parameters: `<|DSML|parameter name="param_name" string="true">value` + - Parameter types are inferred from the tool schema for proper JSON serialization + + Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3.2 + """ + + def __init__(self): + super().__init__() + self.dsml_token = "|DSML|" + self.bot_token = "<|DSML|function_calls>" + self.eot_token = "" # DSML format has no explicit end token + self.invoke_prefix = '<|DSML|invoke name="' + self.parameter_prefix = '<|DSML|parameter name="' + + # Regex for complete parsing + self.invoke_regex = re.compile( + r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)(?=<|DSML|invoke|$)', + re.DOTALL, + ) + # Captures: (param_name, is_string, value) + self.parameter_regex = re.compile( + r'<|DSML|parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>(.*?)(?=<|DSML|parameter|<|DSML|invoke|$)', + re.DOTALL, + ) + + # Streaming state + self._last_arguments = "" + self._current_invoke_text = "" + self._invoke_count = 0 + self._param_count_in_invoke = 0 + self._accumulated_params: Dict[str, str] = {} + self._json_started = False + self._tools_schema: Optional[Dict[str, Dict]] = None + self._tool_indices: Optional[Dict[str, int]] = None + self._current_func_name: Optional[str] = None + self._in_tool_call_sequence = False # Set True once bot_token seen + + def has_tool_call(self, text: str) -> bool: + """Check if the text contains a DeepSeek V3.2 DSML format tool call.""" + return self.bot_token in text + + def _get_param_type(self, func_name: str, 
param_name: str, tools: List[Tool]) -> str: + """Get the JSON Schema type of a parameter from the tool definition.""" + if self._tools_schema is None: + self._tools_schema = {} + for tool in tools: + if tool.function.name and tool.function.parameters: + props = tool.function.parameters.get("properties", {}) + self._tools_schema[tool.function.name] = props + + func_schema = self._tools_schema.get(func_name, {}) + param_schema = func_schema.get(param_name, {}) + return param_schema.get("type", "string") + + def _convert_param_value(self, value: str, is_string_attr: str, param_type: str) -> Any: + """Convert a raw parameter value string to the appropriate Python type. + + Args: + value: The raw string value from the DSML parameter tag. + is_string_attr: The "string" attribute from DSML ("true" or "false"). + If "true", the value is treated as a raw string. + If "false", the value is parsed based on param_type or JSON. + param_type: The JSON Schema type from the tool definition (fallback). + """ + value = value.strip() + if value.lower() == "null": + return None + + # Use DSML string attribute as primary signal + if is_string_attr == "true": + return value + + # string="false" - parse based on schema type or attempt JSON + param_type = param_type.lower() + if param_type in ("integer", "int"): + try: + return int(value) + except (ValueError, TypeError): + return value + elif param_type in ("number", "float"): + try: + val = float(value) + # Only coerce to int if it's actually an integer string + if "." 
not in value and "e" not in value.lower(): + return int(value) + return val + except (ValueError, TypeError, OverflowError): + return value + elif param_type in ("boolean", "bool"): + lower = value.lower() + if lower in ("true", "1"): + return True + elif lower in ("false", "0"): + return False + else: + logger.warning(f"Unexpected boolean value: {value!r}, treating as string") + return value + elif param_type in ("object", "array"): + try: + return json.loads(value) + except json.JSONDecodeError: + return value + else: + # Unknown type with string="false" - try JSON parse, fallback to string + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + def _parse_invoke_params(self, invoke_content: str, func_name: str, tools: List[Tool]) -> Dict: + """Parse all parameters from an invoke block content.""" + params = {} + for param_name, is_string_attr, param_value in self.parameter_regex.findall(invoke_content): + param_name = param_name.strip() + param_value = param_value.strip() + param_type = self._get_param_type(func_name, param_name, tools) + params[param_name] = self._convert_param_value(param_value, is_string_attr, param_type) + return params + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """ + One-time parsing: Detects and parses DSML tool calls in the provided text. 
+ """ + if self.bot_token not in text: + return StreamingParseResult(normal_text=text, calls=[]) + + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx > 0 else "" + tool_section = text[idx:] + + tool_indices = self._get_tool_indices(tools) + calls = [] + + try: + for func_name, invoke_content in self.invoke_regex.findall(tool_section): + func_name = func_name.strip() + if func_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {func_name}") + continue + + params = self._parse_invoke_params(invoke_content, func_name, tools) + calls.append( + ToolCallItem( + tool_index=tool_indices[func_name], + name=func_name, + parameters=json.dumps(params, ensure_ascii=False), + ) + ) + return StreamingParseResult(normal_text=normal_text, calls=calls) + except Exception as e: + logger.error(f"Error in DeepSeekV32 detect_and_parse: {e}") + return StreamingParseResult(normal_text=text) + + def finalize_streaming(self, tools: List[Tool]) -> StreamingParseResult: + """Finalize the last pending tool call when generation ends (EOS). + + The DSML format has no explicit end token, so the last invoke's last + parameter may remain unconfirmed. This method should be called when + the stream ends to close any open JSON and emit remaining parameters. 
+ """ + if not self.current_tool_name_sent or self.current_tool_id < 0: + return StreamingParseResult() + + calls: List[ToolCallItem] = [] + current_text = self._buffer + + try: + # Find current invoke text + invoke_positions = [] + search_start = 0 + while True: + pos = current_text.find(self.invoke_prefix, search_start) + if pos == -1: + break + invoke_positions.append(pos) + search_start = pos + len(self.invoke_prefix) + + if self._invoke_count < len(invoke_positions): + invoke_start = invoke_positions[self._invoke_count] + invoke_text = current_text[invoke_start:] + + name_content_start = len(self.invoke_prefix) + name_end = invoke_text.find('">', name_content_start) + if name_end != -1: + func_name = invoke_text[name_content_start:name_end].strip() + invoke_body = invoke_text[name_end + 2 :] + + # Parse all remaining params (including the last unconfirmed one) + param_matches = list(self.parameter_regex.finditer(invoke_body)) + for i in range(self._param_count_in_invoke, len(param_matches)): + match = param_matches[i] + param_name = match.group(1).strip() + is_string_attr = match.group(2) + param_value = match.group(3).strip() + + param_type = self._get_param_type(func_name, param_name, tools) + converted_value = self._convert_param_value(param_value, is_string_attr, param_type) + serialized_value = json.dumps(converted_value, ensure_ascii=False) + + if not self._json_started: + json_fragment = "{" + f'"{param_name}": {serialized_value}' + self._json_started = True + else: + json_fragment = f', "{param_name}": {serialized_value}' + + self._accumulated_params[param_name] = converted_value + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=json_fragment, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += json_fragment + + # Close the JSON object + if self._json_started: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters="}", + ) + ) + 
self.streamed_args_for_tool[self.current_tool_id] += "}" + elif self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters="{}", + ) + ) + self.streamed_args_for_tool[self.current_tool_id] = "{}" + + # Update prev_tool_call_arr + if self.current_tool_id < len(self.prev_tool_call_arr): + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = self._accumulated_params + + # Reset state + self._invoke_count += 1 + self.current_tool_id += 1 + self.current_tool_name_sent = False + self._json_started = False + self._accumulated_params = {} + self._buffer = "" + + return StreamingParseResult(normal_text="", calls=calls) + except Exception as e: + logger.error(f"Error in DeepSeekV32 finalize_streaming: {e}") + return StreamingParseResult(normal_text="", calls=calls) + + def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult: + """ + Streaming incremental parsing for DeepSeek V3.2 DSML tool calls. + + The DSML format streams line-by-line with invoke/parameter tokens. + We accumulate parameters and only emit JSON fragments when a parameter's + value is confirmed complete (by seeing the next parameter/invoke boundary). 
+ """ + self._buffer += new_text + current_text = self._buffer + + # Check if we have any DSML content + if not self._in_tool_call_sequence: + if not self.has_tool_call(current_text): + # Check for partial start token + if self._ends_with_partial_token(current_text, self.bot_token): + return StreamingParseResult() + self._buffer = "" + return StreamingParseResult(normal_text=new_text) + self._in_tool_call_sequence = True + + if self._tool_indices is None: + self._tool_indices = self._get_tool_indices(tools) + + calls: List[ToolCallItem] = [] + + try: + # Find all invoke starts in current buffer + invoke_positions = [] + search_start = 0 + while True: + pos = current_text.find(self.invoke_prefix, search_start) + if pos == -1: + break + invoke_positions.append(pos) + search_start = pos + len(self.invoke_prefix) + + if not invoke_positions: + # Have bot_token but no invoke yet - keep buffering + return StreamingParseResult() + + # Process only the current (latest) invoke block + current_invoke_idx = self._invoke_count + if current_invoke_idx >= len(invoke_positions): + # All invokes already processed, keep buffering for new ones + return StreamingParseResult() + + invoke_start = invoke_positions[current_invoke_idx] + # Whether the current invoke is bounded by a next invoke + invoke_is_bounded = current_invoke_idx + 1 < len(invoke_positions) + if invoke_is_bounded: + invoke_end = invoke_positions[current_invoke_idx + 1] + else: + invoke_end = len(current_text) + + invoke_text = current_text[invoke_start:invoke_end] + + # Extract function name + name_start = invoke_text.find(self.invoke_prefix) + if name_start == -1: + return StreamingParseResult() + + name_content_start = name_start + len(self.invoke_prefix) + name_end = invoke_text.find('">', name_content_start) + if name_end == -1: + # Function name not complete yet + return StreamingParseResult() + + func_name = invoke_text[name_content_start:name_end].strip() + + # Initialize state for this tool call + if 
self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Send tool name if not sent yet + if not self.current_tool_name_sent: + if func_name and func_name in self._tool_indices: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + self._current_func_name = func_name + self._accumulated_params = {} + self._param_count_in_invoke = 0 + self._json_started = False + return StreamingParseResult(calls=calls) + return StreamingParseResult() + + # Parse parameters from the invoke block content + invoke_body = invoke_text[name_end + 2 :] # after '">' + + # Find all parameter starts within this invoke body + param_positions = [] + ps = 0 + while True: + pp = invoke_body.find(self.parameter_prefix, ps) + if pp == -1: + break + param_positions.append(pp) + ps = pp + len(self.parameter_prefix) + + # A parameter is "confirmed" when the next parameter/invoke boundary is visible, + # meaning the parameter's value won't grow further. + # For the last parameter in the invoke body, it's only confirmed if + # the invoke itself is bounded by a next invoke. 
+ confirmed_count = 0 + for pi in range(len(param_positions)): + if pi + 1 < len(param_positions): + confirmed_count += 1 + elif invoke_is_bounded: + confirmed_count += 1 + + # Only emit newly confirmed parameters + if confirmed_count > self._param_count_in_invoke: + param_matches = list(self.parameter_regex.finditer(invoke_body)) + for i in range(self._param_count_in_invoke, min(confirmed_count, len(param_matches))): + match = param_matches[i] + param_name = match.group(1).strip() + is_string_attr = match.group(2) + param_value = match.group(3).strip() + + param_type = self._get_param_type(func_name, param_name, tools) + converted_value = self._convert_param_value(param_value, is_string_attr, param_type) + serialized_value = json.dumps(converted_value, ensure_ascii=False) + + if not self._json_started: + json_fragment = "{" + f'"{param_name}": {serialized_value}' + self._json_started = True + else: + json_fragment = f', "{param_name}": {serialized_value}' + + self._accumulated_params[param_name] = converted_value + + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=json_fragment, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += json_fragment + + self._param_count_in_invoke = confirmed_count + + # Check if next invoke has started (meaning current one is complete) + if invoke_is_bounded: + # Current invoke is complete, close JSON and advance + if self._json_started: + close_fragment = "}" + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=close_fragment, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += close_fragment + else: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters="{}", + ) + ) + self.streamed_args_for_tool[self.current_tool_id] = "{}" + + # Update prev_tool_call_arr + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = self._accumulated_params + + # Advance to next invoke, prune consumed buffer 
content + # Reset _invoke_count to 0 since buffer positions are now relative + self._buffer = current_text[invoke_end:] + self._invoke_count = 0 + self.current_tool_id += 1 + self.current_tool_name_sent = False + self._last_arguments = "" + self._accumulated_params = {} + self._param_count_in_invoke = 0 + self._json_started = False + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in DeepSeekV32 parse_streaming_increment: {e}") + return StreamingParseResult(normal_text="", calls=calls) + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. @@ -1455,6 +1939,7 @@ class FunctionCallParser: ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = { "deepseekv3": DeepSeekV3Detector, "deepseekv31": DeepSeekV31Detector, + "deepseekv32": DeepSeekV32Detector, "glm47": Glm47Detector, "kimi_k2": KimiK2Detector, "llama3": Llama32Detector, @@ -1535,3 +2020,19 @@ def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]: final_normal_text = sp_result.normal_text return final_normal_text, final_calls + + def finalize_stream(self) -> Tuple[str, list[ToolCallItem]]: + """Finalize streaming when generation ends. + + For detectors that lack an explicit end-of-tool-call token (like DSML), + this closes any pending tool call JSON. For other detectors, this is a no-op. + + Returns: + A tuple of (normal_text, calls) like parse_stream_chunk. 
+ """ + if not self.tools: + return "", [] + if hasattr(self.detector, "finalize_streaming"): + sp_result = self.detector.finalize_streaming(self.tools) + return sp_result.normal_text, sp_result.calls + return "", [] diff --git a/test/chat_template/tool_chat_template_deepseekv32.jinjia b/test/chat_template/tool_chat_template_deepseekv32.jinjia index b6d239dce7..7bb0fc375f 100644 --- a/test/chat_template/tool_chat_template_deepseekv32.jinjia +++ b/test/chat_template/tool_chat_template_deepseekv32.jinjia @@ -1,101 +1,202 @@ -{% if not add_generation_prompt is defined %} - {% set add_generation_prompt = false %} -{% endif %} -{% if not thinking is defined %} - {% set thinking = false %} -{% endif %} -{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false, is_only_sys=false, is_prefix=false) %} -{%- for message in messages %} - {%- if message['role'] == 'system' %} - {%- if ns.is_first_sp %} - {% set ns.system_prompt = ns.system_prompt + message['content'] %} - {% set ns.is_first_sp = false %} - {%- else %} - {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} - {%- endif %} - {% set ns.is_only_sys = true %} - {%- endif %} -{%- endfor %} - -{% if tools is defined and tools is not none %} - {% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %} - {% for tool in tools %} - {% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %} - {% endfor %} - {% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows 
the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %} - {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} -{% endif %} - -{{ bos_token }}{{ ns.system_prompt }} -{%- for message in messages %} - {%- if message['role'] == 'user' %} - {%- set ns.is_tool = false -%} - {%- set ns.is_first = false -%} - {%- set ns.is_last_user = true -%} - {{'<|User|>' + message['content']}} - {%- endif %} - {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} - {%- if ns.is_last_user or ns.is_only_sys %} - {{'<|Assistant|>'}} - {%- endif %} - {%- set ns.is_last_user = false -%} - {%- set ns.is_first = false %} - {%- set ns.is_tool = false -%} - {%- for tool in message['tool_calls'] %} - {%- set formatted_args = tool['function']['arguments'] if tool['function']['arguments'] is string else tool['function']['arguments']|tojson %} - {%- if not ns.is_first %} - {%- if message['content'] is none %} - {{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- else %} - {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- endif %} - {%- set ns.is_first = true -%} - {%- else %} - {{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + formatted_args + '<|tool▁call▁end|>'}} - {%- endif %} - {%- endfor %} - {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} - {%- endif %} - {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %} - {%- if ns.is_last_user %} - {{'<|Assistant|>'}} - {%- if message['prefix'] is defined and message['prefix'] and thinking %} - {{''}} - {%- else %} - {{''}} - {%- endif %} - {%- endif %} - {%- if message['prefix'] is defined and message['prefix'] %} - {%- set ns.is_prefix = true -%} - {%- 
endif %}
-        {%- set ns.is_last_user = false -%}
-        {%- if ns.is_tool %}
-            {{message['content'] + '<|end▁of▁sentence|>'}}
-            {%- set ns.is_tool = false -%}
-        {%- else %}
-            {%- set content = message['content'] -%}
-            {%- if '</think>' in content %}
-                {%- set content = content.split('</think>', 1)[1] -%}
-            {%- endif %}
-            {{content + '<|end▁of▁sentence|>'}}
-        {%- endif %}
-    {%- endif %}
-    {%- if message['role'] == 'tool' %}
-        {%- set ns.is_last_user = false -%}
-        {%- set ns.is_tool = true -%}
-        {{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
-    {%- endif %}
-    {%- if message['role'] != 'system' %}
-        {% set ns.is_only_sys = false %}
-    {%- endif %}
+{#- ============================================================================
+    DeepSeek-V3.2 DSML Chat Template
+    Converted from encoding_dsv32.py encode_messages function.
+    Uses DSML (DeepSeek Markup Language) format for tool calls.
+    ============================================================================ -#}
+{%- set bos_token = "<|begin▁of▁sentence|>" -%}
+{%- set eos_token = "<|end▁of▁sentence|>" -%}
+{%- set thinking_start_token = "<think>" -%}
+{%- set thinking_end_token = "</think>" -%}
+{%- set dsml_token = "|DSML|" -%}
+
+{%- set system_msg_template = "{content}" -%}
+{%- set user_msg_template = "<|User|>{content}<|Assistant|>" -%}
+{%- set assistant_msg_template = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>" -%}
+{%- set thinking_template = "<think>{reasoning_content}</think>" -%}
+{%- set tool_call_template = "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>" -%}
+{%- set tool_calls_template = "<{dsml_token}function_calls>\n{tool_calls}\n</{dsml_token}function_calls>" -%}
+{%- set tool_output_template = "\n{content}" -%}
+
+{%- set TOOLS_SYSTEM_TEMPLATE -%}
+## Tools
+You have access to a set of tools you can use to answer the user's question.
+You can invoke functions by writing a "<{{ dsml_token }}function_calls>" block like the following as part of your reply to the user:
+
+<{{ dsml_token }}function_calls>
+<{{ dsml_token }}invoke name="$FUNCTION_NAME">
+<{{ dsml_token }}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{{ dsml_token }}parameter>
+...
+</{{ dsml_token }}invoke>
+<{{ dsml_token }}invoke name="$FUNCTION_NAME2">
+...
+</{{ dsml_token }}invoke>
+</{{ dsml_token }}function_calls>
+
+String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects).
+
+Here are the functions available in JSONSchema format:
+{tool_schemas}
+{%- endset -%}
+
+{%- if thinking_mode is not defined -%}
+    {%- set thinking_mode = "thinking" -%}
+{%- endif -%}
+{%- if drop_thinking is not defined -%}
+    {%- set drop_thinking = true -%}
+{%- endif -%}
+{%- if add_default_bos_token is not defined -%}
+    {%- set add_default_bos_token = true -%}
+{%- endif -%}
+
+{#- Macro: encode_arguments_to_dsml -#}
+{%- macro encode_arguments_to_dsml(arguments) -%}
+    {%- set ns = namespace(P_dsml_strs=[]) -%}
+    {%- if arguments is mapping -%}
+        {%- for k, v in arguments.items() -%}
+            {%- if v is string -%}
+                {%- set is_str = "true" -%}
+                {%- set value = v -%}
+            {%- else -%}
+                {%- set is_str = "false" -%}
+                {%- set value = v | tojson -%}
+            {%- endif -%}
+            {%- set p_dsml_str = "<" ~ dsml_token ~ "parameter name=\"" ~ k ~ "\" string=\"" ~ is_str ~ "\">" ~ value ~ "</" ~ dsml_token ~ "parameter>" -%}
+            {%- set ns.P_dsml_strs = ns.P_dsml_strs + [p_dsml_str] -%}
+        {%- endfor -%}
+    {%- endif -%}
+    {{- ns.P_dsml_strs | join("\n") -}}
+{%- endmacro -%}
+
+{#- Macro: render_tools -#}
+{%- macro render_tools(tools) -%}
+    {%- set ns = namespace(tools_json=[]) -%}
+    {%- for tool in tools -%}
+        {%- if tool.function is defined -%}
+            {%-
set ns.tools_json = ns.tools_json + [tool.function | tojson] -%}
+        {%- else -%}
+            {%- set ns.tools_json = ns.tools_json + [tool | tojson] -%}
+        {%- endif -%}
+    {%- endfor -%}
+    {{- TOOLS_SYSTEM_TEMPLATE | replace("{tool_schemas}", ns.tools_json | join("\n")) }}
+{% endmacro -%}
+
+{#- Macro: find_last_user_index -#}
+{%- macro find_last_user_index(messages) -%}
+    {%- set ns = namespace(last_user_index=-1) -%}
+    {%- for msg in messages -%}
+        {%- set role = msg.role if msg.role is defined else msg.get('role') -%}
+        {%- if role in ['user', 'developer'] -%}
+            {%- set ns.last_user_index = loop.index0 -%}
+        {%- endif -%}
+    {%- endfor -%}
+    {{- ns.last_user_index -}}
+{%- endmacro -%}
+
+{#- Macro: render_tool_calls_content -#}
+{%- macro render_tool_calls_content(tool_calls) -%}
+    {%- set ns = namespace(formatted_calls=[]) -%}
+    {%- for tool_call in tool_calls -%}
+        {%- if tool_call.function is defined -%}
+            {%- set name = tool_call.function.name -%}
+            {%- set arguments = tool_call.function.arguments -%}
+        {%- else -%}
+            {%- set name = tool_call.name -%}
+            {%- set arguments = tool_call.arguments -%}
+        {%- endif -%}
+        {%- if arguments is string -%}
+            {%- set arguments = arguments | fromjson -%}
+        {%- endif -%}
+        {%- set formatted_call = "<" ~ dsml_token ~ "invoke name=\"" ~ name ~ "\">\n" ~ encode_arguments_to_dsml(arguments) ~ "\n</" ~ dsml_token ~ "invoke>" -%}
+        {%- set ns.formatted_calls = ns.formatted_calls + [formatted_call] -%}
+    {%- endfor -%}
+    {{- "<" ~ dsml_token ~ "function_calls>\n" ~ ns.formatted_calls | join("\n") ~ "\n</" ~ dsml_token ~ "function_calls>" -}}
+{%- endmacro -%}
+
+{#- Macro: render_message -#}
+{%- macro render_message(index, messages, thinking_mode) -%}
+    {%- set msg = messages[index] -%}
+    {%- set last_user_idx = find_last_user_index(messages) | int -%}
+    {%- set role = msg.role if msg.role is defined else msg.get('role') -%}
+    {%- set content = msg.content if msg.content is defined else (msg.get('content', '') or '') -%}
+    {%- set msg_tools =
msg.tools if msg.tools is defined else msg.get('tools', []) -%} + {%- set tool_calls = msg.tool_calls if msg.tool_calls is defined else msg.get('tool_calls', []) -%} + {%- set reasoning_content = msg.reasoning_content if msg.reasoning_content is defined else (msg.get('reasoning_content', '') or '') -%} + + {%- if role == 'system' -%} + {{- content or '' -}} + {%- if msg_tools -%} + {{- "\n\n" ~ render_tools(msg_tools) -}} + {%- endif -%} + + {%- elif role == 'user' -%} + {{- "<|User|>" ~ content ~ "<|Assistant|>" -}} + {%- if index == last_user_idx and thinking_mode == "thinking" -%} + {{- thinking_start_token -}} + {%- else -%} + {{- thinking_end_token -}} + {%- endif -%} + + {%- elif role == 'tool' -%} + {%- set ns = namespace(prev_assistant_idx=-1) -%} + {%- for i in range(index - 1, -1, -1) -%} + {%- set check_role = messages[i].role if messages[i].role is defined else messages[i].get('role') -%} + {%- if check_role != 'tool' and ns.prev_assistant_idx == -1 -%} + {%- set ns.prev_assistant_idx = i -%} + {%- endif -%} + {%- endfor -%} + {%- set tool_call_order = index - ns.prev_assistant_idx -%} + {%- set assistant_msg = messages[ns.prev_assistant_idx] -%} + {%- set assistant_tool_calls = assistant_msg.tool_calls if assistant_msg.tool_calls is defined else assistant_msg.get('tool_calls', []) -%} + {%- if tool_call_order == 1 -%} + {{- "\n\n" -}} + {%- endif -%} + {{- "\n" ~ content -}} + {%- if tool_call_order == (assistant_tool_calls | length) -%} + {{- "\n" -}} + {%- if index >= last_user_idx and thinking_mode == "thinking" -%} + {{- "\n\n" ~ thinking_start_token -}} + {%- else -%} + {{- "\n\n" ~ thinking_end_token -}} + {%- endif -%} + {%- endif -%} + + {%- elif role == 'assistant' -%} + {%- set ns = namespace(thinking_part="", tool_calls_content="") -%} + {%- if tool_calls -%} + {%- set ns.tool_calls_content = "\n\n" ~ render_tool_calls_content(tool_calls) -%} + {%- endif -%} + {%- set summary_content = content or "" -%} + {%- if thinking_mode == "thinking" 
and index > last_user_idx -%} + {%- set ns.thinking_part = reasoning_content ~ thinking_end_token -%} + {%- endif -%} + {{- ns.thinking_part ~ summary_content ~ ns.tool_calls_content ~ "<|end▁of▁sentence|>" -}} + {%- endif -%} +{%- endmacro -%} + +{#- Main template body -#} +{%- set full_messages = messages -%} + +{#- Handle tools in top-level (OpenAI format) -#} +{%- if tools is defined and tools is not none -%} + {%- set ns_sys = namespace(has_system=false, sys_idx=-1) -%} + {%- for msg in full_messages -%} + {%- set role = msg.role if msg.role is defined else msg.get('role') -%} + {%- if role == 'system' and not ns_sys.has_system -%} + {%- set ns_sys.has_system = true -%} + {%- set ns_sys.sys_idx = loop.index0 -%} + {%- endif -%} + {%- endfor -%} +{%- endif -%} + +{%- if add_default_bos_token -%} + {{- bos_token -}} +{%- endif -%} + +{#- If tools defined at top level but no system message has them, prepend tools info -#} +{%- if tools is defined and tools is not none -%} + {{- render_tools(tools) -}} +{%- endif -%} + +{%- for msg in full_messages -%} + {{- render_message(loop.index0, full_messages, thinking_mode) -}} {%- endfor -%} -{% if add_generation_prompt and not ns.is_tool%} - {% if ns.is_last_user or ns.is_only_sys or not ns.is_prefix %} - {{'<|Assistant|>'}} - {%- if not thinking %} - {{''}} - {%- else %} - {{''}} - {%- endif %} - {% endif %} -{% endif %} diff --git a/test/test_api/test_gsmk.py b/test/test_api/test_gsmk.py new file mode 100644 index 0000000000..2d9ead65b8 --- /dev/null +++ b/test/test_api/test_gsmk.py @@ -0,0 +1,265 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py +import argparse +import ast +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import numpy as np +import requests +from tqdm import tqdm + +INVALID = -9999999 + +SYSTEM_PROMPT_TARGET_LEN = 18192 + + +def generate_system_prompt(): + """Generate a 
system prompt of approximately SYSTEM_PROMPT_TARGET_LEN (18192) characters."""
+    base = (
+        "You are a highly capable math assistant. Your task is to solve grade school math problems step by step. "
+        "Show your reasoning clearly and provide the final numerical answer. "
+        "Break down each problem into smaller steps and verify your calculations. "
+        "Always end your answer with the format: #### <answer>. "
+    )
+    # Repeat base text to reach target length
+    repeats = SYSTEM_PROMPT_TARGET_LEN // len(base) + 1
+    prompt = (base * repeats)[:SYSTEM_PROMPT_TARGET_LEN]
+    return prompt
+
+
+def read_jsonl(filename: str):
+    """Read a JSONL file."""
+    with open(filename) as fin:
+        for line in fin:
+            if line.startswith("#"):
+                continue
+            yield json.loads(line)
+
+
+def dump_state_text(filename: str, states: list, mode: str = "w"):
+    """Dump program state in a text file."""
+    with open(filename, mode) as fout:
+        for i, s in enumerate(states):
+            if isinstance(s, str):
+                fout.write(f"==== {i} ====\n{s}\n")
+            else:
+                fout.write(f"==== {i} ====\n{str(s)}\n")
+
+
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
+
+    # Check if the cache file already exists
+    if os.path.exists(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")
+
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
+
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
+
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as file, tqdm(
+        desc="Downloading",
+        total=total_size,
+        unit="iB",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            size = file.write(chunk)
+            bar.update(size)
+
+    return filename
+
+
+def 
call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): + """Call LightLLM API for text generation.""" + assert url is not None + + data = { + "inputs": prompt, + "parameters": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "stop_sequences": stop, + "repetition_penalty": 1.0, + "top_p": 1.0, + "top_k": 1, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}" + + response_json = res.json() + if "generated_text" not in response_json: + raise ValueError(f"Invalid API response format. Expected 'generated_text' key, got: {response_json.keys()}") + if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0: + raise ValueError( + "Invalid API response format. 'generated_text' should be a non-empty list, " + f"got: {response_json['generated_text']}" + ) + + pred = response_json["generated_text"][0] + return pred + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + # First try to find the answer after "####" marker (GSM8K format) + match = re.search(r"####\s*(-?\d+)", answer_str) + if match: + try: + return ast.literal_eval(match.group(1)) + except SyntaxError: + pass + # Fallback: find all numbers and take the last one + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", type=int, default=256) + 
parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--num-shots", type=int, default=5)
+    parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--result-file", type=str, default="result.jsonl")
+    parser.add_argument("--data-path", type=str, default="test.jsonl")
+    parser.add_argument(
+        "--system-prompt", action="store_true", help="Prepend an 18192-character system prompt to each request"
+    )
+    return parser.parse_args()
+
+
+def main(args):
+    # LightLLM API URL
+    url = f"{args.host}:{args.port}/generate"
+
+    # Read data
+    url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url_data)
+    lines = list(read_jsonl(filename))
+
+    # Construct prompts
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+    system_prefix = ""
+    if args.system_prompt:
+        system_prefix = generate_system_prompt() + "\n\n"
+        print(f"System prompt enabled: {len(system_prefix)} characters")
+
+    # Ensure we have enough samples and avoid data leakage
+    # Test questions should start after few-shot examples
+    max_available = len(lines) - num_shots
+    if num_questions > max_available:
+        print(
+            "Warning: Requested {} questions, but only {} available after reserving {} for few-shot. 
" + "Using {} questions.".format(num_questions, max_available, num_shots, max_available) + ) + num_questions = max_available + + questions = [] + labels = [] + for i in range(num_shots, num_shots + num_questions): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(label != INVALID for label in labels) + + states = [None] * len(labels) + + # Run requests using thread pool + def get_one_answer(i): + answer = call_generate_lightllm( + prompt=system_prefix + few_shot_examples + questions[i], + temperature=0, + max_tokens=1024, + stop=["Question", "Assistant:", "<|separator|>", "Human:", "\n\nQuestion"], + url=url, + ) + states[i] = answer + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + + latency = time.perf_counter() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + + # Dump results + dump_state_text("tmp_output_lightllm.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": "lightllm", + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + args = parse_args() + main(args) From e51d4cedd09cf029957540b1437d4bcdcd77c292 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 4 Feb 2026 08:39:18 +0000 Subject: [PATCH 
16/58] fix --- .../deepseek3_2/layer_infer/nsa_indexer_layer_inder.py | 3 +-- .../deepseek3_2/layer_infer/transformer_layer_infer.py | 6 +++--- .../layer_weights/nsa_indexer_layer_weight.py | 4 ++-- .../layer_weights/transformer_layer_weight.py | 10 +++------- 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 3855bf590f..61b4962f15 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -16,11 +16,10 @@ class NSAIndexerInfer(BaseLayerInfer): - def __init__(self, layer_idx, network_config, mode=[]): + def __init__(self, layer_idx, network_config): super().__init__() self.layer_idx_ = layer_idx self.network_config_ = network_config - self.mode = mode self.index_topk = network_config["index_topk"] self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ self.tp_k_head_num_ = 1 diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index bc8bb9c6ba..b7326c36eb 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -14,11 +14,11 @@ class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): + def __init__(self, layer_num, network_config): self.index_topk = network_config["index_topk"] - super().__init__(layer_num, network_config, mode) + super().__init__(layer_num, network_config) - self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_, mode=mode) + self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_) self.topk_indices = None # Initialize NSA attention 
backend (singleton, lazy initialization) diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py index 9ccfbe97ed..9e1337b0fa 100644 --- a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -7,8 +7,8 @@ class NSAIndexerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg): + super().__init__(layer_num, data_type, network_config, quant_cfg) return @override diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py index 2a03e1d6a1..adcba51cc9 100644 --- a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py @@ -3,14 +3,10 @@ class Deepseek3_2TransformerLayerWeight(Deepseek2TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): self.index_topk = network_config["index_topk"] - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) self.indexer_layer_weight = NSAIndexerWeight( - layer_num=layer_num, - data_type=data_type, - network_config=network_config, - mode=mode, - quant_cfg=quant_cfg + layer_num=layer_num, data_type=data_type, network_config=network_config, quant_cfg=quant_cfg ) return From 7d8be57b57f3f4fbffcbbdd44fad418e16d14589 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 4 Feb 2026 08:45:23 +0000 Subject: [PATCH 17/58] fix --- 
.../deepseek3_2/layer_infer/nsa_indexer_layer_inder.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 61b4962f15..390853271a 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -1,8 +1,6 @@ from sgl_kernel import fast_topk_transform_fused import deep_gemm import torch -import torch.nn.functional as F - from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo @@ -138,10 +136,7 @@ def _get_q_k_bf16( q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) k = layer_weight.wk_proj_.mm(hidden_states) - # TODO - k = F.layer_norm( - k.float(), (self.index_head_dim,), layer_weight.k_norm_.weight, layer_weight.k_norm_.bias, self.eps - ).type_as(k) + k = layer_weight.k_norm_(k, eps=self.eps) # Slice position_cos and position_sin to match actual token length actual_seq_len = q.shape[0] From da09156c4bc46a282d95f1087e70c01e169bc10c Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 4 Feb 2026 09:50:12 +0000 Subject: [PATCH 18/58] fix --- .../attention/nsa/flashmla_sparse.py | 7 +- lightllm/common/infer_utils.py | 9 +- .../triton_kernel/fp8_mqa_logits.py | 124 +++-- .../triton_kernel/token_group_quant.py | 9 +- lightllm/server/build_prompt.py | 50 +- lightllm/server/encoding_dsv32.py | 429 ++++++++++++++++++ 6 files changed, 564 insertions(+), 64 deletions(-) create mode 100644 lightllm/server/encoding_dsv32.py diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index 3eec98f055..2c347ed32b 
100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -1,3 +1,6 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/nsa_backend.py +# Uses sgl_kernel.flash_mla and sgl_kernel.flash_attn from the sglang kernel library. + import dataclasses import torch from typing import Tuple, TYPE_CHECKING @@ -112,8 +115,8 @@ def _nsa_decode_att( q_nope, q_rope = q # Extract k_rope and kv_nope from the KV buffer - k_rope = kv[:, :, -qk_rope_head_dim:].reshape(-1, 1, 1, qk_rope_head_dim) - kv_nope = kv[:, :, :-qk_rope_head_dim].reshape(-1, 1, 1, kv_lora_rank) + k_rope = kv[:, :, -qk_rope_head_dim:].view(-1, 1, 1, qk_rope_head_dim) + kv_nope = kv[:, :, :-qk_rope_head_dim].view(-1, 1, 1, kv_lora_rank) o_tensor = flash_attn_with_kvcache( q=q_rope, diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index 26cf973be4..e1b9cc3830 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -1,15 +1,8 @@ -import torch from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req_prefill def init_req_to_token_indexes( - req_to_token_indexs, - b_req_idx, - b_seq_len, - b_ready_cache_len, - b_start_loc, - alloc_mem_index, - max_q_seq_len, + req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, b_start_loc, alloc_mem_index, max_q_seq_len ): copy_kv_index_to_req_prefill( req_to_token_indexs=req_to_token_indexs, diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py index 2fc92662af..e8f1bbfa21 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py +++ b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py @@ -1,3 +1,4 @@ +# import triton import triton.language as tl import torch @@ -5,13 +6,27 @@ @triton.jit def _fp8_paged_mqa_logits_kernel( - Q_ptr, KV_ptr, KVScale_ptr, 
Weights_ptr, MemIndex_ptr, - CuSeqlenKs_ptr, CuSeqlenKe_ptr, Output_ptr, - seq_len, seq_len_kv, num_heads, head_dim, - stride_q_seq, stride_q_head, stride_q_dim, - stride_kv_pool, stride_kv_dim, - stride_w_seq, stride_w_head, - stride_o_seq, stride_o_kv, + Q_ptr, + KV_ptr, + KVScale_ptr, + Weights_ptr, + MemIndex_ptr, + CuSeqlenKs_ptr, + CuSeqlenKe_ptr, + Output_ptr, + seq_len, + seq_len_kv, + num_heads, + head_dim, + stride_q_seq, + stride_q_head, + stride_q_dim, + stride_kv_pool, + stride_kv_dim, + stride_w_seq, + stride_w_head, + stride_o_seq, + stride_o_kv, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_D: tl.constexpr, @@ -19,80 +34,79 @@ def _fp8_paged_mqa_logits_kernel( pid_m = tl.program_id(0) pid_n = tl.program_id(1) - + # Compute the range of seq positions this block handles start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N - + # Offset arrays for this block offs_m = start_m + tl.arange(0, BLOCK_SIZE_M) offs_n = start_n + tl.arange(0, BLOCK_SIZE_N) - + # Initialize accumulator for logits logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - + # Create masks mask_m = offs_m < seq_len mask_n = offs_n < seq_len_kv - + # Load mem_indices for the KV positions mem_indices = tl.load(MemIndex_ptr + offs_n, mask=mask_n, other=0) - + # Load scales for K scales = tl.load(KVScale_ptr + mem_indices, mask=mask_n, other=1.0) - + # Loop over all heads for h in range(num_heads): # Load weights for this head - weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, - mask=mask_m, other=0.0) - + weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, mask=mask_m, other=0.0) + # Initialize score accumulator for this head score = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - + # Loop over head_dim in blocks for d_block in range(tl.cdiv(head_dim, BLOCK_SIZE_D)): d_start = d_block * BLOCK_SIZE_D offs_d = d_start + tl.arange(0, BLOCK_SIZE_D) mask_d = offs_d < head_dim - + # Load Q 
for this head and dimension block # Q shape: (seq_len, num_heads, head_dim) q_ptrs = Q_ptr + offs_m[:, None] * stride_q_seq + h * stride_q_head + offs_d[None, :] * stride_q_dim mask_q = (offs_m[:, None] < seq_len) & mask_d[None, :] q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32) - + # Load K for this dimension block # KV shape: (pool_size, head_dim) as FP8 data k_ptrs = KV_ptr + mem_indices[:, None] * stride_kv_pool + offs_d[None, :] * stride_kv_dim mask_k = mask_n[:, None] & mask_d[None, :] k = tl.load(k_ptrs, mask=mask_k, other=0.0).to(tl.float32) - + # Apply scale to K (scale is per-row of K) k = k * scales[:, None] - + # Compute partial dot product: q @ k.T # q: (BLOCK_SIZE_M, BLOCK_SIZE_D), k: (BLOCK_SIZE_N, BLOCK_SIZE_D) # score: (BLOCK_SIZE_M, BLOCK_SIZE_N) score += tl.dot(q, tl.trans(k)) - + # Apply ReLU to score score = tl.maximum(score, 0.0) - + # Multiply by weights and accumulate to logits logits += score * weights[:, None] - + # Apply mask based on cu_seqlen_ks and cu_seqlen_ke mask_ks = tl.load(CuSeqlenKs_ptr + offs_m, mask=mask_m, other=0) mask_ke = tl.load(CuSeqlenKe_ptr + offs_m, mask=mask_m, other=seq_len_kv) - + mask_lo = offs_n[None, :] >= mask_ks[:, None] mask_hi = offs_n[None, :] < mask_ke[:, None] mask_valid = mask_lo & mask_hi & mask_m[:, None] & mask_n[None, :] - + # Apply mask (-inf for masked positions) - logits = tl.where(mask_valid, logits, float('-inf')) - + logits = tl.where(mask_valid, logits, float("-inf")) + # Store output out_ptrs = Output_ptr + offs_m[:, None] * stride_o_seq + offs_n[None, :] * stride_o_kv mask_out = (offs_m[:, None] < seq_len) & (offs_n[None, :] < seq_len_kv) @@ -100,40 +114,54 @@ def _fp8_paged_mqa_logits_kernel( def fp8_paged_mqa_logits( - q: torch.Tensor, + q: torch.Tensor, kv: torch.Tensor, kv_scale: torch.Tensor, - weights: torch.Tensor, - mem_index: torch.Tensor, - cu_seqlen_ks: torch.Tensor, + weights: torch.Tensor, + mem_index: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: 
torch.Tensor, - out: torch.Tensor = None + out: torch.Tensor = None, ) -> torch.Tensor: seq_len, num_heads, head_dim = q.shape seq_len_kv = mem_index.shape[0] - + if out is None: output = torch.empty((seq_len, seq_len_kv), device=q.device, dtype=torch.float32) else: output = out - + BLOCK_SIZE_M = 16 BLOCK_SIZE_N = 64 - BLOCK_SIZE_D = 128 - + BLOCK_SIZE_D = 128 + grid = (triton.cdiv(seq_len, BLOCK_SIZE_M), triton.cdiv(seq_len_kv, BLOCK_SIZE_N)) - + _fp8_paged_mqa_logits_kernel[grid]( - q, kv, kv_scale, weights, mem_index, - cu_seqlen_ks, cu_seqlen_ke, output, - seq_len, seq_len_kv, num_heads, head_dim, - q.stride(0), q.stride(1), q.stride(2), - kv.stride(0), kv.stride(1), - weights.stride(0), weights.stride(1), - output.stride(0), output.stride(1), + q, + kv, + kv_scale, + weights, + mem_index, + cu_seqlen_ks, + cu_seqlen_ke, + output, + seq_len, + seq_len_kv, + num_heads, + head_dim, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + weights.stride(0), + weights.stride(1), + output.stride(0), + output.stride(1), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_D=BLOCK_SIZE_D, ) - - return output \ No newline at end of file + + return output diff --git a/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py index dbf5c51992..8079864133 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py +++ b/lightllm/models/deepseek3_2/triton_kernel/token_group_quant.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py + import triton import triton.language as tl import torch @@ -7,6 +9,7 @@ fp8_max = 448.0 fp8_dtype = torch.float8_e4m3fn + @triton.jit def _per_token_group_quant_mla_deep_gemm_masked_fp8( y_ptr, @@ -46,9 +49,7 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8( mask = cols < group_size for gid in range(NUM_GROUP): - y = tl.load(y_ptr + gid * 
group_size + cols, mask=mask, other=0.0).to( - tl.float32 - ) + y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(tl.float32) _absmax = tl.maximum(tl.max(tl.abs(y)), eps) y_s = _absmax / fp8_max y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) @@ -100,4 +101,4 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8( BLOCK_SIZE, ) - return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m \ No newline at end of file + return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index f770459a55..d77184863f 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -1,11 +1,28 @@ +import json +import os + tokenizer = None +_model_type = None def init_tokenizer(args): - global tokenizer + global tokenizer, _model_type from lightllm.server.tokenizer import get_tokenizer tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + + # Detect model type for specialized encoding (e.g. 
DeepSeek-V3.2) + config_path = os.path.join(args.model_dir, "config.json") + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as f: + model_config = json.load(f) + _model_type = model_config.get("model_type", None) + # Check architectures as fallback + if _model_type is None: + archs = model_config.get("architectures", []) + if any("DeepseekV32" in a for a in archs): + _model_type = "deepseek_v32" + chat_path = args.chat_template if chat_path is not None: with open(chat_path, "r", encoding="utf-8") as f: @@ -14,9 +31,14 @@ def init_tokenizer(args): async def build_prompt(request, tools) -> str: - global tokenizer + global tokenizer, _model_type # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] + + # Use DeepSeek-V3.2 native encoding when applicable + if _model_type == "deepseek_v32": + return _build_prompt_dsv32(messages, tools, request) + kwargs = {"conversation": messages} if request.character_settings: kwargs["character_settings"] = request.character_settings @@ -40,3 +62,27 @@ async def build_prompt(request, tools) -> str: tools=tools, ) return input_str + + +def _build_prompt_dsv32(messages, tools, request): + from lightllm.server.encoding_dsv32 import encode_messages + + # Inject tools into system message if present + if tools is not None and len(tools) > 0: + wrapped_tools = [t if "function" in t else {"function": t} for t in tools] + if messages and messages[0].get("role") == "system": + messages[0]["tools"] = wrapped_tools + else: + messages.insert(0, {"role": "system", "tools": wrapped_tools}) + + # Determine thinking mode from request + thinking = False + if request.chat_template_kwargs: + thinking = request.chat_template_kwargs.get("thinking", False) or request.chat_template_kwargs.get( + "enable_thinking", False + ) + + thinking_mode = "thinking" if thinking else "chat" + drop_thinking = messages[-1]["role"] == "user" 
if messages else True + + return encode_messages(messages, thinking_mode=thinking_mode, drop_thinking=drop_thinking) diff --git a/lightllm/server/encoding_dsv32.py b/lightllm/server/encoding_dsv32.py new file mode 100644 index 0000000000..3ac4b83714 --- /dev/null +++ b/lightllm/server/encoding_dsv32.py @@ -0,0 +1,429 @@ +# Adapted from vLLM's deepseek_v32_encoding.py +# (https://github.com/vllm-project/vllm), which was originally adapted from +# https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/encoding/encoding_dsv32.py +# +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +import json +import re +from typing import Any + +# flake8: noqa: E501 +TOOLS_SYSTEM_TEMPLATE = """## Tools +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user: +<{dsml_token}function_calls> +<{dsml_token}invoke name="$FUNCTION_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$FUNCTION_NAME2"> +... + + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). +If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example: +<{dsml_token}function_calls> +... + + +... 
+ +{thinking_start_token}...thinking about results{thinking_end_token} +Here are the functions available in JSONSchema format: + +{tool_schemas} + +""" + +bos_token: str = "<|begin▁of▁sentence|>" +eos_token: str = "<|end▁of▁sentence|>" +thinking_start_token: str = "" +thinking_end_token: str = "" +dsml_token: str = "|DSML|" +system_msg_template: str = "{content}" +user_msg_template: str = "<|User|>{content}<|Assistant|>" +assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>" +thinking_template = "{reasoning}" + +response_format_template: str = ( + "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}" +) +tool_call_template: str = '<{dsml_token}invoke name="{name}">\n{arguments}\n' +tool_calls_template = "<{dsml_token}function_calls>\n{tool_calls}\n" + +tool_output_template: str = "\n{content}" + + +def to_json(value: Any) -> str: + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return json.dumps(value, ensure_ascii=True) + + +def tools_from_openai_format(tools): + return [tool["function"] for tool in tools] + + +def tool_calls_from_openai_format(tool_calls): + return [ + { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + for tool_call in tool_calls + ] + + +def tool_calls_to_openai_format(tool_calls): + return [ + { + "type": "function", + "function": { + "name": tool_call["name"], + "arguments": tool_call["arguments"], + }, + } + for tool_call in tool_calls + ] + + +def encode_arguments_to_dsml(tool_call: dict) -> str: + p_dsml_template = """<{dsml_token}parameter name="{key}" string="{is_str}">{value}""" + P_dsml_strs = [] + if isinstance(tool_call["arguments"], str): + arguments = json.loads(tool_call["arguments"]) + else: + arguments = tool_call["arguments"] + + for k, v in arguments.items(): + p_dsml_str = p_dsml_template.format( + dsml_token=dsml_token, + key=k, + is_str="true" if isinstance(v, str) else "false", + 
value=v if isinstance(v, str) else to_json(v), + ) + + P_dsml_strs.append(p_dsml_str) + + return "\n".join(P_dsml_strs) + + +def decode_dsml_to_arguments(tool_name, tool_args): + def _decode_value(key, value, string): + if string == "true": + value = to_json(value) + return f"{to_json(key)}: {value}" + + tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}" + return dict(name=tool_name, arguments=tool_args_json) + + +def render_tools(tools): + tools_json = [to_json(t) for t in tools] + + return TOOLS_SYSTEM_TEMPLATE.format( + tool_schemas="\n".join(tools_json), + dsml_token=dsml_token, + thinking_start_token=thinking_start_token, + thinking_end_token=thinking_end_token, + ) + + +def find_last_user_index(messages): + last_user_index = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx].get("role") in ["user", "developer"]: + last_user_index = idx + break + return last_user_index + + +def render_message(index, messages, thinking_mode): + if not (0 <= index < len(messages)): + raise ValueError(f"Index {index} out of range for messages list of length {len(messages)}") + if thinking_mode not in ["chat", "thinking"]: + raise ValueError(f"Invalid thinking_mode `{thinking_mode}`") + + prompt = "" + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = msg.get("role") + content = msg.get("content") + tools = msg.get("tools") + response_format = msg.get("response_format") + tool_calls = msg.get("tool_calls") + reasoning = msg.get("reasoning") + is_prefix = msg.get("prefix", False) + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + if role == "system": + prompt += system_msg_template.format(content=content or "") + if tools: + prompt += "\n\n" + render_tools(tools) + + if response_format: + prompt += "\n\n" + response_format_template.format(schema=to_json(response_format)) + + elif role 
== "developer": + if not content: + raise ValueError(f"Invalid message for role `{role}`: {msg}") + content_developer = "" + if tools: + content_developer += "\n\n" + render_tools(tools) + + if response_format: + content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format)) + + content_developer += "\n\n# The user's message is: {}".format(content) + + prompt += user_msg_template.format(content=content_developer) + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "user": + prompt += user_msg_template.format(content=content) + + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "tool": + prev_assistant_idx = index - 1 + assistant_msg = messages[prev_assistant_idx] + while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool": + prev_assistant_idx -= 1 + assistant_msg = messages[prev_assistant_idx] + + if not (index == 0 or prev_assistant_idx >= 0 and assistant_msg.get("role") == "assistant"): + raise ValueError(f"Invalid messages at {index}:\n{assistant_msg}") + + tool_call_order = index - prev_assistant_idx + assistant_tool_calls = assistant_msg.get("tool_calls") + if not (assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order): + raise ValueError("No tool calls but found tool output") + + if tool_call_order == 1: + prompt += "\n\n" + + prompt += tool_output_template.format(content=content) + + if tool_call_order == len(assistant_tool_calls): + prompt += "\n" + + if index >= last_user_idx and thinking_mode == "thinking": + prompt += "\n\n" + thinking_start_token + else: + prompt += "\n\n" + thinking_end_token + + elif role == "assistant": + thinking_part = "" + + tool_calls_content = "" + if tool_calls: + tool_calls = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tool_call.get("name"), + 
arguments=encode_arguments_to_dsml(tool_call), + ) + for tool_call in tool_calls + ] + tool_calls_content += "\n\n" + tool_calls_template.format( + dsml_token=dsml_token, tool_calls="\n".join(tool_calls) + ) + + summary_content = content or "" + + if thinking_mode == "thinking" and index > last_user_idx: + if not (reasoning or tool_calls): + raise ValueError( + f"ThinkingMode: {thinking_mode}, invalid message without reasoning/tool_calls `{msg}` after last user message" + ) + thinking_part = thinking_template.format(reasoning=reasoning or "") + thinking_end_token + + if not tool_calls and is_prefix: + prompt += summary_content + else: + prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tool_calls_content, + ) + else: + raise NotImplementedError(f"Unknown role: {role}") + + return prompt + + +def drop_thinking_messages(messages, last_user_idx=None): + messages_wo_thinking = [] + last_user_idx = find_last_user_index(messages) if last_user_idx is None else last_user_idx + for idx, msg in enumerate(messages): + role = msg.get("role") + if role in ["user", "system", "tool"] or idx >= last_user_idx: + messages_wo_thinking.append(msg) + continue + + elif role == "assistant": + msg_wo_thinking = copy.copy(msg) + msg_wo_thinking.pop("reasoning", None) + messages_wo_thinking.append(msg_wo_thinking) + + return messages_wo_thinking + + +def encode_messages( + messages, + thinking_mode, + context=None, + drop_thinking=True, + add_default_bos_token=True, +): + context = context if context else [] + full_messages = context + messages + + prompt = bos_token if add_default_bos_token and len(context) == 0 else "" + + if thinking_mode == "thinking" and drop_thinking: + full_messages = drop_thinking_messages(full_messages) + + for idx in range(len(messages)): + prompt += render_message(idx + len(context), full_messages, thinking_mode=thinking_mode) + + return prompt + + +def _read_until_stop(index, text, stop): + min_pos = 
len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls(index, text): + tool_calls = [] + stop_token = None + tool_calls_end_token = f"" + + while index < len(text): + index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token]) + if _ != ">\n": + raise RuntimeError("Tool call format error") + + if stop_token == tool_calls_end_token: + break + + if stop_token is None: + raise RuntimeError("Missing special token") + + index, tool_name_content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n$', tool_name_content, flags=re.DOTALL) + if len(p_tool_name) != 1: + raise RuntimeError("Tool name format error") + tool_name = p_tool_name[0] + + tool_args = {} + while stop_token == f"<{dsml_token}parameter": + index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"]) + + param_kv = re.findall( + r'^ name="(.*?)" string="(true|false)">(.*?)<$', + param_content, + flags=re.DOTALL, + ) + if len(param_kv) != 1: + raise RuntimeError("Parameter format error") + param_name, string, param_value = param_kv[0] + + if param_name in tool_args: + raise RuntimeError("Duplicate parameter name") + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n": + raise RuntimeError("Parameter format error") + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +# NOTE: This function is designed to parse only correctly +# formatted string and will not attempt to correct malformed output +# that may be generated 
by the model. +def parse_message_from_completion_text(text, thinking_mode): + summary_content, reasoning, tool_calls = "", "", [] + index, stop_token = 0, None + tool_calls_start_token = f"\n\n<{dsml_token}function_calls" + + is_thinking, is_tool_calling = thinking_mode == "thinking", False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token]) + reasoning = content_delta + if stop_token != thinking_end_token: + raise RuntimeError("Invalid thinking format") + + index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token]) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + if stop_token != eos_token: + raise RuntimeError("Invalid summary format") + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + if tool_ends_text: + raise RuntimeError("Unexpected content after tool calls") + + if not (len(text) == index and stop_token in [eos_token, None]): + raise RuntimeError("Unexpected content at end") + + for sp_token in [ + bos_token, + eos_token, + thinking_start_token, + thinking_end_token, + dsml_token, + ]: + if sp_token in summary_content or sp_token in reasoning: + raise RuntimeError("Unexpected special token in content") + + return { + "role": "assistant", + "content": summary_content, + "reasoning": reasoning, + "tool_calls": tool_calls_to_openai_format(tool_calls), + } From 512e32c0abe8b86a151e296e4e0330adc005dcd4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 01:45:00 +0000 Subject: [PATCH 19/58] rebase --- .../fused_moe/fused_moe_weight.py | 2 +- .../deepseek3_2}/encoding_dsv32.py | 0 lightllm/models/deepseek3_2/infer_struct.py | 142 +++-- .../layer_infer/nsa_indexer_layer_inder.py | 8 +- .../layer_infer/transformer_layer_infer.py | 11 +- 
.../layer_weights/nsa_indexer_layer_weight.py | 35 +- lightllm/models/deepseek3_2/mem_manager.py | 6 - lightllm/models/deepseek3_2/model.py | 104 +++- .../destindex_copy_indexer_ks.py | 168 ----- .../triton_kernel/extract_indexer_ks.py | 265 +------- .../triton_kernel/fp8_mqa_logits.py | 26 - lightllm/server/api_cli.py | 9 +- lightllm/server/api_openai.py | 33 - lightllm/server/build_prompt.py | 50 +- lightllm/server/core/objs/sampling_params.py | 45 +- lightllm/server/function_call_parser.py | 586 ++++++------------ lightllm/server/tokenizer.py | 11 + 17 files changed, 458 insertions(+), 1043 deletions(-) rename lightllm/{server => models/deepseek3_2}/encoding_dsv32.py (100%) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 6bcf7fc03c..8f54e14a72 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -98,7 +98,7 @@ def _init_parallel_params(self): self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_ if self.enable_ep_moe: assert self.num_fused_shared_experts == 0, "num_fused_shared_experts must be 0 when enable_ep_moe" - logger.info( + logger.debug( f"global_rank {self.global_rank_} layerindex {self.layer_num_} " f"redundancy_expertids: {self.redundancy_expert_ids}" ) diff --git a/lightllm/server/encoding_dsv32.py b/lightllm/models/deepseek3_2/encoding_dsv32.py similarity index 100% rename from lightllm/server/encoding_dsv32.py rename to lightllm/models/deepseek3_2/encoding_dsv32.py diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/infer_struct.py index e0cca499bd..779c2fc2d2 100644 --- a/lightllm/models/deepseek3_2/infer_struct.py +++ b/lightllm/models/deepseek3_2/infer_struct.py @@ -4,7 +4,7 @@ from lightllm.models.deepseek3_2.mem_manager 
import Deepseek3_2MemoryManager -class Deepseek3_2FlashAttentionStateInfo(Deepseek2InferStateInfo): +class Deepseek3_2InferStateInfo(Deepseek2InferStateInfo): _shared_nsa_buffers = None def __init__(self): @@ -54,7 +54,7 @@ def _check_use_cuda_graph_buffers(self): and hasattr(model, "graph_max_batch_size") and hasattr(model, "graph_max_len_in_batch") and self.batch_size <= model.graph_max_batch_size - and self.max_len_in_batch <= model.graph_max_len_in_batch + and self.max_kv_seq_len <= model.graph_max_len_in_batch ): return True return False @@ -68,12 +68,13 @@ def init_some_extra_state(self, model): self.indexer_ks_buffer = self.mem_manager.indexer_ks_buffer if self.is_prefill: - pass + self._init_nsa_indexing_prefill() else: if self.b_ready_cache_len is None: self.b_ready_cache_len = torch.zeros_like(self.b_seq_len) use_cuda_graph_buffers = self._check_use_cuda_graph_buffers() + buffer = None if use_cuda_graph_buffers: buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) @@ -84,86 +85,125 @@ def init_some_extra_state(self, model): self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device="cuda") self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device="cuda") - self.nsa_cache_seqlens.copy_(self.b_att_seq_len.clamp(max=self.index_topk)) + self.nsa_cache_seqlens.copy_(self.b_kv_seq_len.clamp(max=self.index_topk)) assert self.nsa_cache_seqlens.dtype == torch.int32 torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) self.nsa_cu_seqlens_k[0] = 0 - self._init_nsa_indexing_structures() + self._init_nsa_indexing_decode(use_cuda_graph_buffers, buffer) - def _init_nsa_indexing_structures(self): - """Pre-compute ks, ke, lengths, and page_table_size_1 for NSA indexer. + def _init_nsa_indexing_decode(self, use_cuda_graph_buffers, buffer): + """Optimized NSA indexing for decode: b_q_seq_len=1 per request. 
- Fully vectorized: eliminates per-request .item() CPU-GPU syncs. + In decode, each request generates exactly 1 token, so: + - total_q_len = batch_size (no .item() needed) + - ks[i] = cumsum_offset[i], ke[i] = cumsum_offset[i] + 1 + - lengths[i] = b_seq_len[i] + - No repeat_interleave, no token_in_req math needed. """ b_seq_len = self.b_seq_len + b_req_idx = self.b_req_idx + num_seq = self.batch_size + + # Cumulative seq_len offsets for ks/ke: [0, s0, s0+s1, ...] + cum_seq = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) + + if use_cuda_graph_buffers: + model = self._model_ref() + max_seq_len = model.graph_max_len_in_batch + + # ks, ke, lengths — write directly into buffer slices + buf_ks = buffer["ks"][:num_seq] + buf_ke = buffer["ke"][:num_seq] + buf_lengths = buffer["lengths"][:num_seq] + + # ks[0] = 0, ks[i] = cum_seq[i-1] + buf_ks[0] = 0 + if num_seq > 1: + buf_ks[1:].copy_(cum_seq[: num_seq - 1]) + # ke = ks + 1 + torch.add(buf_ks, 1, out=buf_ke) + # lengths = b_seq_len + buf_lengths.copy_(b_seq_len.int()) + + self.ks = buf_ks + self.ke = buf_ke + self.lengths = buf_lengths + + # page_table: zero buffer slice, then fill valid entries + page_table = buffer["page_table_size_1"][:num_seq, :max_seq_len] + page_table.zero_() + all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] + seq_range = torch.arange(max_seq_len, device=b_seq_len.device) + valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) + page_table[valid_mask] = all_rows[valid_mask].int() + self.page_table_size_1 = page_table + + # req_all_mem_index: use padded [num_seq * max_seq_len] layout + # Downstream uses ks/ke masking so padded entries are safe + max_total_seq = num_seq * max_seq_len + buf_mem = buffer["req_all_mem_index"][:max_total_seq] + buf_mem.copy_(all_rows.reshape(-1)) + self.req_all_mem_index = buf_mem + else: + # Non-CUDA-graph decode: simplified formulas, fresh tensors + max_seq_len = b_seq_len.max().item() + + # ks/ke/lengths + seq_offsets = 
torch.empty_like(cum_seq) + seq_offsets[0] = 0 + if num_seq > 1: + seq_offsets[1:] = cum_seq[:-1] + + self.ks = seq_offsets + self.ke = (seq_offsets + 1).int() + self.lengths = b_seq_len.int() + + # page_table and req_all_mem_index + all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] + seq_range = torch.arange(max_seq_len, device=b_seq_len.device) + valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) + + page_table = torch.zeros((num_seq, max_seq_len), dtype=torch.int, device=b_seq_len.device) + page_table[valid_mask] = all_rows[valid_mask].int() + self.page_table_size_1 = page_table + + self.req_all_mem_index = all_rows[valid_mask] + + def _init_nsa_indexing_prefill(self): + """NSA indexing for prefill: variable q lengths, generic vectorized path.""" + b_seq_len = self.b_seq_len b_q_seq_len = self.b_q_seq_len b_req_idx = self.b_req_idx num_seq = b_req_idx.shape[0] device = b_seq_len.device - # Only 3 scalar syncs needed (for tensor shapes) max_seq_len = b_seq_len.max().item() total_q_len = b_q_seq_len.sum().item() - total_seq_len = b_seq_len.sum().item() - # --- page_table_size_1 and req_all_mem_index (vectorized gather) --- + # page_table_size_1 and req_all_mem_index all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] seq_range = torch.arange(max_seq_len, device=device) valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) - # page_table_size_1: [batch, max_seq_len] zero-padded memory indices page_table = torch.zeros((num_seq, max_seq_len), dtype=torch.int, device=device) page_table[valid_mask] = all_rows[valid_mask].int() + self.page_table_size_1 = page_table + self.req_all_mem_index = all_rows[valid_mask] - # req_all_mem_index: flattened valid memory indices across all requests - req_all_mem_index = all_rows[valid_mask] - - # --- ks, ke, lengths (vectorized computation) --- - # Cumulative seq_len offsets: [0, seq_len[0], seq_len[0]+seq_len[1], ...] 
+ # ks, ke, lengths — generic vectorized for variable q lengths cum_seq = torch.cumsum(b_seq_len, dim=0) seq_offsets = torch.zeros_like(cum_seq) seq_offsets[1:] = cum_seq[:-1] - # Expand per-request values to per-token using repeat_interleave req_indices = torch.repeat_interleave(torch.arange(num_seq, device=device), b_q_seq_len) - # Token position within each request's q_seq cum_q = torch.cumsum(b_q_seq_len, dim=0) q_offsets = torch.zeros_like(cum_q) q_offsets[1:] = cum_q[:-1] token_in_req = torch.arange(total_q_len, device=device) - q_offsets[req_indices] - # ks[t] = seq_offset of request owning token t - # ke[t] = seq_offset + position_in_q + 1 - # lengths[t] = seq_len - q_seq_len + position_in_q + 1 - ks = seq_offsets[req_indices].int() - ke = (seq_offsets[req_indices] + token_in_req + 1).int() - lengths = (b_seq_len[req_indices] - b_q_seq_len[req_indices] + token_in_req + 1).int() - - # --- Assign results (CUDA graph buffer or new tensors) --- - use_cuda_graph_buffers = self._check_use_cuda_graph_buffers() - - if use_cuda_graph_buffers: - model = self._model_ref() - buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) - buffer = buffers[self.microbatch_index] - - self.ks = buffer["ks"][:total_q_len] - self.ke = buffer["ke"][:total_q_len] - self.lengths = buffer["lengths"][:total_q_len] - self.page_table_size_1 = buffer["page_table_size_1"][:num_seq, :max_seq_len] - self.req_all_mem_index = buffer["req_all_mem_index"][:total_seq_len] - - self.ks.copy_(ks) - self.ke.copy_(ke) - self.lengths.copy_(lengths) - self.page_table_size_1.copy_(page_table) - self.req_all_mem_index.copy_(req_all_mem_index) - else: - self.ks = ks - self.ke = ke - self.lengths = lengths - self.page_table_size_1 = page_table - self.req_all_mem_index = req_all_mem_index + self.ks = seq_offsets[req_indices].int() + self.ke = (seq_offsets[req_indices] + token_in_req + 1).int() + self.lengths = (b_seq_len[req_indices] - b_q_seq_len[req_indices] + token_in_req 
+ 1).int() diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 390853271a..7a9aeb46c9 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -3,7 +3,7 @@ import torch from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks @@ -70,7 +70,7 @@ def get_indices( self, hidden_states: torch.Tensor, q_lora: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionStateInfo, + infer_state: Deepseek3_2InferStateInfo, layer_weight: NSAIndexerWeight, ) -> torch.Tensor: @@ -113,7 +113,7 @@ def get_indices( score=logits, lengths=lengths, page_table_size_1=page_table_1, - cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_q=infer_state.b1_cu_q_seq_len, topk=self.index_topk, ) @@ -130,7 +130,7 @@ def _get_q_k_bf16( self, hidden_states: torch.Tensor, q_lora: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionStateInfo, + infer_state: Deepseek3_2InferStateInfo, layer_weight: NSAIndexerWeight, ): q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index b7326c36eb..9dba923cc1 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ 
b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -5,7 +5,7 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd @@ -37,7 +37,7 @@ def _get_nsa_backend(self): def _get_qkv( self, input: torch.Tensor, - infer_state: Deepseek3_2FlashAttentionStateInfo, + infer_state: Deepseek3_2InferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) @@ -47,9 +47,6 @@ def _get_qkv( ) q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) - # Process all tokens for indexer - # Note: Prefix cache slicing optimization is disabled due to batch structure - # mismatch issues with fast_topk_transform_fused kernel self.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) q = layer_weight.q_b_proj_.mm(q) @@ -76,7 +73,7 @@ def _context_attention_kernel( self, q: torch.Tensor, kv, - infer_state: Deepseek3_2FlashAttentionStateInfo, + infer_state: Deepseek3_2InferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: @@ -111,7 +108,7 @@ def _context_attention_kernel( def _token_attention_kernel( self, q, - infer_state: Deepseek3_2FlashAttentionStateInfo, + infer_state: Deepseek3_2InferStateInfo, layer_weight: 
Deepseek3_2TransformerLayerWeight, out=None, ): diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py index 9e1337b0fa..6df1a88215 100644 --- a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -11,40 +11,47 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg): super().__init__(layer_num, data_type, network_config, quant_cfg) return + @override + def _parse_config(self): + self.q_lora_rank = self.network_config_["q_lora_rank"] + self.index_n_heads = self.network_config_["index_n_heads"] + self.index_head_dim = self.network_config_["index_head_dim"] + self.hidden_size = self.network_config_["hidden_size"] + @override def _init_weight(self): prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" self.wq_b_proj_ = ROWMMWeight( - weight_name=f"{prefix}.wq_b.weight", + in_dim=self.q_lora_rank, + out_dims=[self.index_n_heads * self.index_head_dim], + weight_names=f"{prefix}.wq_b.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="wq_b", + quant_method=None, tp_rank=0, tp_world_size=1, ) self.wk_proj_ = ROWMMWeight( - weight_name=f"{prefix}.wk.weight", + in_dim=self.hidden_size, + out_dims=[self.index_head_dim], + weight_names=f"{prefix}.wk.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="wk", + quant_method=None, tp_rank=0, tp_world_size=1, ) self.k_norm_ = LayerNormWeight( - dim=self.network_config_["index_head_dim"], + dim=self.index_head_dim, weight_name=f"{prefix}.k_norm.weight", - data_type=torch.float32, + data_type=self.data_type_, bias_name=f"{prefix}.k_norm.bias", ) self.weights_proj_ = ROWMMWeight( - weight_name=f"{prefix}.weights_proj.weight", + in_dim=self.hidden_size, + out_dims=[self.index_n_heads], + 
weight_names=f"{prefix}.weights_proj.weight", data_type=self.data_type_, - quant_cfg=None, - layer_num=self.layer_num_, - name="weights_proj", + quant_method=None, tp_rank=0, tp_world_size=1, ) diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index 8017a84adc..dc78f1de4c 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -6,12 +6,6 @@ class IndexerKSBuffer: - """Lightweight buffer holder for NSA indexer keys+scales. - - Shares token indices with the parent MemoryManager — does NOT have its - own allocator. Only stores the per-layer kv_buffer tensor. - """ - def __init__(self, size: int, head_num: int, head_dim: int, layer_num: int, dtype=torch.uint8): self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index d25cbd3783..f907b0bed6 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -1,11 +1,107 @@ +import copy +import json +import logging + from lightllm.models.registry import ModelRegistry from lightllm.models.deepseek2.model import Deepseek2TpPartModel from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2FlashAttentionStateInfo +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager +_logger = logging.getLogger(__name__) + + +class DeepSeekV32Tokenizer: + """Tokenizer wrapper for DeepSeek-V3.2 that uses the Python-based + encoding_dsv32 module instead of Jinja 
chat templates. + + DeepSeek-V3.2's tokenizer_config.json does not ship with a Jinja chat + template, so ``apply_chat_template`` would fail without either a manually + supplied ``--chat_template`` file or this wrapper. Activate it with + ``--tokenizer_mode deepseek_v32``. + """ + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + # Cache added vocabulary for performance (HuggingFace can be slow). + self._added_vocab = None + + # ------------------------------------------------------------------ + # Attribute delegation – everything not overridden goes to the inner + # tokenizer so that encode/decode/vocab_size/eos_token_id/… all work. + # ------------------------------------------------------------------ + def __getattr__(self, name): + return getattr(self.tokenizer, name) + + def get_added_vocab(self): + if self._added_vocab is None: + self._added_vocab = self.tokenizer.get_added_vocab() + return self._added_vocab + + # ------------------------------------------------------------------ + # Core override: route apply_chat_template through encode_messages. + # ------------------------------------------------------------------ + def apply_chat_template( + self, + conversation=None, + messages=None, + tools=None, + tokenize=False, + add_generation_prompt=True, + thinking=None, + **kwargs, + ): + from lightllm.models.deepseek3_2.encoding_dsv32 import encode_messages, render_tools + + msgs = conversation if conversation is not None else messages + if msgs is None: + raise ValueError("Either 'conversation' or 'messages' must be provided") + + # Deep copy to avoid mutating the caller's messages. + msgs = copy.deepcopy(msgs) + + # Determine thinking mode. + thinking_mode = "thinking" if thinking else "chat" + + # Inject tools into the first system message (or create one) so that + # encode_messages / render_message picks them up. 
+ if tools: + # build_prompt passes tools as bare function dicts: + # [{"name": "f", "description": "...", "parameters": {...}}] + # encoding_dsv32's render_message expects OpenAI wrapper format: + # [{"type": "function", "function": {...}}] + wrapped_tools = [] + for t in tools: + if "function" in t: + wrapped_tools.append(t) + else: + wrapped_tools.append({"type": "function", "function": t}) + + injected = False + for msg in msgs: + if msg.get("role") == "system": + existing = msg.get("tools") or [] + msg["tools"] = existing + wrapped_tools + injected = True + break + + if not injected: + # Prepend a system message that carries the tools. + msgs.insert(0, {"role": "system", "content": "", "tools": wrapped_tools}) + + prompt = encode_messages( + msgs, + thinking_mode=thinking_mode, + drop_thinking=kwargs.get("drop_thinking", True), + add_default_bos_token=kwargs.get("add_default_bos_token", True), + ) + + if tokenize: + return self.tokenizer.encode(prompt, add_special_tokens=False) + return prompt + @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): @@ -16,7 +112,7 @@ class Deepseek3_2TpPartModel(Deepseek2TpPartModel): transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer # infer state class - infer_state_class = Deepseek3_2FlashAttentionStateInfo + infer_state_class = Deepseek3_2InferStateInfo def __init__(self, kvargs): super().__init__(kvargs) @@ -24,11 +120,11 @@ def __init__(self, kvargs): return def _init_inferstate_cls(self): - self.infer_state_class = Deepseek3_2FlashAttentionStateInfo + self.infer_state_class = Deepseek3_2InferStateInfo def _init_mem_manager(self): manager_class = Deepseek3_2MemoryManager - if "triton_fp8kv" in self.mode: + if get_env_start_args().llm_kv_type == "fp8kv": manager_class = Deepseek3_2FP8KVMemoryManager # mtp 模式下需要在mem manger上扩展draft model使用的layer diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py 
b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py index 8faf3cdea4..a345bd1e20 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -19,16 +19,6 @@ def _fwd_kernel_destindex_copy_indexer_ks( stride_o_d, BLOCK_DMODEL: tl.constexpr, ): - """ - Triton kernel to copy FP8 K values and their scales to an indexed output buffer. - - This kernel reads FP8 key values (128 dims) and their float32 scale values, - then writes them to a compact buffer format where each entry contains: - - Bytes 0-127: FP8 key values (128 bytes) - - Bytes 128-131: Float32 scale (4 bytes) - - The destination location for each source element is specified by DestLoc. - """ cur_index = tl.program_id(0) offs_d = tl.arange(0, BLOCK_DMODEL) @@ -64,37 +54,6 @@ def _fwd_kernel_destindex_copy_indexer_ks( def destindex_copy_indexer_ks( K_fp8: torch.Tensor, K_scale: torch.Tensor, DestLoc: torch.Tensor, O_buffer: torch.Tensor ): - """ - Copy FP8-quantized key values and their scales to indexed locations in a buffer. - - This function is used in the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) - mechanism to store compressed key representations in a memory buffer. Each key - is stored with its FP8 representation (128 bytes) followed by its float32 scale - (4 bytes), for a total of 132 bytes per key. - - Args: - K_fp8: [q_seq_len, 128] torch.fp8_e4m3fn - FP8-quantized key values - K_scale: [q_seq_len, 1] torch.float32 - Quantization scales for each key - DestLoc: [q_seq_len] torch.int32 - Destination indices in the output buffer - O_buffer: [large_size, 1, 132] torch.uint8 - Output buffer where keys and scales will be written. - Must be a uint8 tensor to allow mixed-type storage. 
- Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales - - Returns: - None (modifies O_buffer in-place) - - Example: - >>> k_fp8 = torch.randn(50, 128).to(torch.float8_e4m3fn).cuda() - >>> k_scale = torch.randn(50, 1).cuda() - >>> dest_loc = torch.randint(0, 1024, (50,), dtype=torch.int32).cuda() - >>> o_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() - >>> destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer) - >>> # Now o_buffer[dest_loc] contains the packed k_fp8 and k_scale data - """ seq_len = DestLoc.shape[0] head_dim = K_fp8.shape[1] @@ -129,130 +88,3 @@ def destindex_copy_indexer_ks( num_stages=1, ) return - - -def test_destindex_copy_indexer_ks(): - """Test the destindex_copy_indexer_ks kernel""" - import torch.nn.functional as F - - print("=" * 80) - print("Testing destindex_copy_indexer_ks") - print("=" * 80) - - # Test parameters - q_seq_len = 50 - head_dim = 128 - large_size = 1024 - dtype = torch.bfloat16 - fp8_type = torch.float8_e4m3fn - - # Create random destination indices - dest_loc = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() - actual_seq_len = len(dest_loc) - - # Create input tensors - k_bf16 = torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") - - # Quantize to FP8 - k_abs_max = k_bf16.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8 = (k_bf16 / k_abs_max).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - - # Create output buffer (as uint8 to allow reinterpretation) - o_buffer_uint8 = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - - # Run kernel - destindex_copy_indexer_ks(k_fp8, k_scale, dest_loc, o_buffer_uint8) - - # Extract results - k_fp8_out = o_buffer_uint8[:, 0, :128].view(fp8_type) - - # Extract scale by reinterpreting 4 bytes as float32 - scale_bytes = o_buffer_uint8[:, 0, 128:132].contiguous() - k_scale_out = 
scale_bytes.view(-1, 4).view(torch.float32).squeeze(-1) - - # Verify results at destination locations - k_fp8_extracted = k_fp8_out[dest_loc] - k_scale_extracted = k_scale_out[dest_loc] - - # Check FP8 values match - fp8_match = torch.allclose(k_fp8_extracted.to(torch.float32), k_fp8.to(torch.float32), atol=0, rtol=0) - - # Check scales match - scale_match = torch.allclose(k_scale_extracted, k_scale.squeeze(-1), atol=1e-6, rtol=1e-5) - - # Check dequantized values - k_dequant_out = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) - cosine_sim = F.cosine_similarity(k_dequant_out, k_bf16, dim=-1).mean() - - print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") - print(f" FP8 values match: {fp8_match}") - print(f" Scale values match: {scale_match}") - print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") - - assert fp8_match, "FP8 values do not match!" - assert scale_match, "Scale values do not match!" - assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" - - print("✓ Basic test passed!") - print() - - # Test edge cases - print("Testing edge cases...") - - # Test with sequential indices - dest_loc_seq = torch.arange(20, device="cuda", dtype=torch.int32) - k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") - k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - - o_buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, dest_loc_seq, o_buffer_seq) - - k_fp8_out_seq = o_buffer_seq[:20, 0, :128].view(fp8_type) - scale_bytes_seq = o_buffer_seq[:20, 0, 128:132].contiguous() - k_scale_out_seq = scale_bytes_seq.view(-1, 4).view(torch.float32).squeeze(-1) - - fp8_match_seq = 
torch.allclose(k_fp8_out_seq.to(torch.float32), k_fp8_seq.to(torch.float32), atol=0, rtol=0) - scale_match_seq = torch.allclose(k_scale_out_seq, k_scale_seq.squeeze(-1), atol=1e-6, rtol=1e-5) - - print(f" Sequential indices test: FP8={fp8_match_seq}, Scale={scale_match_seq}") - assert fp8_match_seq and scale_match_seq - print("✓ Edge case tests passed!") - print() - - # Test with single element - print("Testing single element...") - dest_loc_single = torch.tensor([42], device="cuda", dtype=torch.int32) - k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") - k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_single = ( - (k_bf16_single / k_abs_max_single).clamp(torch.finfo(fp8_type).min, torch.finfo(fp8_type).max).to(fp8_type) - ) - - o_buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_single, k_scale_single, dest_loc_single, o_buffer_single) - - k_fp8_out_single = o_buffer_single[42:43, 0, :128].view(fp8_type) - scale_bytes_single = o_buffer_single[42:43, 0, 128:132].contiguous() - k_scale_out_single = scale_bytes_single.view(-1, 4).view(torch.float32).squeeze(-1) - - fp8_match_single = torch.allclose( - k_fp8_out_single.to(torch.float32), k_fp8_single.to(torch.float32), atol=0, rtol=0 - ) - scale_match_single = torch.allclose(k_scale_out_single, k_scale_single.squeeze(-1), atol=1e-6, rtol=1e-5) - - print(f" Single element test: FP8={fp8_match_single}, Scale={scale_match_single}") - assert fp8_match_single and scale_match_single - print("✓ Single element test passed!") - print() - - print("=" * 80) - print("All tests passed successfully! 
✓") - print("=" * 80) - - -if __name__ == "__main__": - test_destindex_copy_indexer_ks() diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index eb22fbb8f7..48bc34ad6e 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -6,10 +6,10 @@ @triton.jit def _fwd_kernel_extract_indexer_ks( - I_buffer, # Input buffer [large_size, 1, 132] uint8 - SrcLoc, # Source indices [req_size] int32/int64 - O_fp8, # Output FP8 [req_size, 128] float8_e4m3fn - O_scale, # Output scale [req_size] float32 + I_buffer, # Input buffer [large_size, 1, 132] uint8 + SrcLoc, # Source indices [req_size] int32/int64 + O_fp8, # Output FP8 [req_size, 128] float8_e4m3fn + O_scale, # Output scale [req_size] float32 stride_i_bs, stride_i_h, stride_i_d, @@ -18,98 +18,51 @@ def _fwd_kernel_extract_indexer_ks( stride_o_scale_bs, BLOCK_DMODEL: tl.constexpr, ): - """ - Triton kernel to extract FP8 K values and their scales from an indexed buffer. - - This kernel is the inverse of destindex_copy_indexer_ks. It reads from a - compact buffer format where each entry contains: - - Bytes 0-127: FP8 key values (128 bytes) - - Bytes 128-131: Float32 scale (4 bytes) - - The source location for each output element is specified by SrcLoc. 
- """ cur_index = tl.program_id(0) offs_d = tl.arange(0, BLOCK_DMODEL) - - # Load source index for this thread + src_index = tl.load(SrcLoc + cur_index).to(tl.int64) - - # Load K_fp8 from I_buffer[:, 0, :128] + i_k_ptrs = I_buffer + src_index * stride_i_bs + stride_i_d * offs_d k_fp8_as_uint8 = tl.load(i_k_ptrs) - - # Convert uint8 to fp8 through bitcast + k_fp8 = k_fp8_as_uint8.to(tl.float8e4nv, bitcast=True) - - # Store K_fp8 to output + o_k_ptrs = O_fp8 + cur_index * stride_o_fp8_bs + stride_o_fp8_d * offs_d tl.store(o_k_ptrs, k_fp8) - - # Load K_scale from I_buffer[:, 0, 128:132] (4 bytes for float32) - # Load 4 bytes and reconstruct float32 (little-endian) + i_scale_base_ptr = I_buffer + src_index * stride_i_bs + BLOCK_DMODEL * stride_i_d - - # Load 4 bytes individually and combine them into uint32 + byte0 = tl.load(i_scale_base_ptr + 0 * stride_i_d).to(tl.uint32) byte1 = tl.load(i_scale_base_ptr + 1 * stride_i_d).to(tl.uint32) byte2 = tl.load(i_scale_base_ptr + 2 * stride_i_d).to(tl.uint32) byte3 = tl.load(i_scale_base_ptr + 3 * stride_i_d).to(tl.uint32) - - # Combine bytes into uint32 (little-endian: byte0 is LSB) + scale_as_uint32 = byte0 | (byte1 << 8) | (byte2 << 16) | (byte3 << 24) - - # Bitcast uint32 to float32 + k_scale = scale_as_uint32.to(tl.float32, bitcast=True) - - # Store scale to output + o_scale_ptr = O_scale + cur_index * stride_o_scale_bs tl.store(o_scale_ptr, k_scale) - + return @torch.no_grad() def extract_indexer_ks(I_buffer: torch.Tensor, SrcLoc: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Extract FP8-quantized key values and their scales from indexed locations in a buffer. - - This function is the inverse operation of destindex_copy_indexer_ks. It's used in - the DeepSeek-V3.2 NSA (Neighbor-aware Sparse Attention) mechanism to retrieve - compressed key representations from a memory buffer. - - Args: - I_buffer: [large_size, 1, 132] torch.uint8 - Input buffer containing packed FP8 keys and float32 scales. 
- Format: [:, 0, :128] = FP8 keys, [:, 0, 128:132] = float32 scales - SrcLoc: [req_size] torch.int32 or torch.int64 - Source indices to extract from the input buffer - - Returns: - tuple containing: - - K_fp8: [req_size, 128] torch.float8_e4m3fn - FP8-quantized key values - - K_scale: [req_size] torch.float32 - Quantization scales for each key - - Example: - >>> i_buffer = torch.zeros(1024, 1, 132, dtype=torch.uint8).cuda() - >>> src_loc = torch.tensor([10, 20, 30], dtype=torch.int32).cuda() - >>> k_fp8, k_scale = extract_indexer_ks(i_buffer, src_loc) - >>> # k_fp8.shape == [3, 128], k_scale.shape == [3] - """ req_size = SrcLoc.shape[0] head_dim = 128 - + assert I_buffer.dtype == torch.uint8, f"Expected I_buffer dtype=uint8, got {I_buffer.dtype}" assert I_buffer.shape[2] == 132, f"Expected I_buffer last dim=132, got {I_buffer.shape[2]}" - + # Allocate output tensors O_fp8 = torch.empty((req_size, head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) O_scale = torch.empty((req_size,), dtype=torch.float32, device=I_buffer.device) - + grid = (req_size,) num_warps = 1 - + _fwd_kernel_extract_indexer_ks[grid]( I_buffer, SrcLoc, @@ -125,185 +78,5 @@ def extract_indexer_ks(I_buffer: torch.Tensor, SrcLoc: torch.Tensor) -> tuple[to num_warps=num_warps, num_stages=1, ) - - return O_fp8, O_scale - - -def test_extract_indexer_ks(): - """Test the extract_indexer_ks kernel against the copy kernel""" - import torch.nn.functional as F - from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks - - print("=" * 80) - print("Testing extract_indexer_ks") - print("=" * 80) - - # Test parameters - q_seq_len = 50 - head_dim = 128 - large_size = 1024 - dtype = torch.bfloat16 - fp8_type = torch.float8_e4m3fn - - # Create random indices for writing - write_indices = torch.randint(0, large_size, (q_seq_len,), device="cuda", dtype=torch.int32).unique() - actual_seq_len = len(write_indices) - - # Create input tensors - k_bf16_original = 
torch.randn((actual_seq_len, head_dim), dtype=dtype, device="cuda") - - # Quantize to FP8 - k_abs_max = k_bf16_original.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_original = (k_abs_max / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_original = (k_bf16_original / k_abs_max).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - - # Create buffer and write data using destindex_copy_indexer_ks - buffer = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_original, k_scale_original, write_indices, buffer) - - # Now extract the data back using extract_indexer_ks - k_fp8_extracted, k_scale_extracted = extract_indexer_ks(buffer, write_indices) - - # Verify FP8 values match - fp8_match = torch.allclose( - k_fp8_extracted.to(torch.float32), - k_fp8_original.to(torch.float32), - atol=0, rtol=0 - ) - - # Verify scales match - scale_match = torch.allclose( - k_scale_extracted, - k_scale_original.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - - # Check dequantized values - k_dequant_extracted = k_fp8_extracted.to(dtype) * k_scale_extracted.unsqueeze(-1) - cosine_sim = F.cosine_similarity(k_dequant_extracted, k_bf16_original, dim=-1).mean() - - print(f"Test with seq_len={actual_seq_len}, head_dim={head_dim}") - print(f" FP8 values match: {fp8_match}") - print(f" Scale values match: {scale_match}") - print(f" Cosine similarity after dequantization: {cosine_sim:.6f}") - - assert fp8_match, "FP8 values do not match!" - assert scale_match, "Scale values do not match!" 
- assert cosine_sim > 0.99, f"Cosine similarity too low: {cosine_sim}" - - print("✓ Basic test passed!") - print() - - # Test with sequential indices - print("Testing sequential indices...") - write_indices_seq = torch.arange(20, device="cuda", dtype=torch.int32) - k_bf16_seq = torch.randn((20, head_dim), dtype=dtype, device="cuda") - k_abs_max_seq = k_bf16_seq.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_seq = (k_abs_max_seq / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_seq = (k_bf16_seq / k_abs_max_seq).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - - buffer_seq = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_seq, k_scale_seq, write_indices_seq, buffer_seq) - k_fp8_ext_seq, k_scale_ext_seq = extract_indexer_ks(buffer_seq, write_indices_seq) - - fp8_match_seq = torch.allclose( - k_fp8_ext_seq.to(torch.float32), - k_fp8_seq.to(torch.float32), - atol=0, rtol=0 - ) - scale_match_seq = torch.allclose( - k_scale_ext_seq, - k_scale_seq.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - - print(f" Sequential indices: FP8={fp8_match_seq}, Scale={scale_match_seq}") - assert fp8_match_seq and scale_match_seq - print("✓ Sequential test passed!") - print() - - # Test with single element - print("Testing single element...") - write_idx_single = torch.tensor([42], device="cuda", dtype=torch.int32) - k_bf16_single = torch.randn((1, head_dim), dtype=dtype, device="cuda") - k_abs_max_single = k_bf16_single.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_single = (k_abs_max_single / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_single = (k_bf16_single / k_abs_max_single).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - - buffer_single = torch.zeros((large_size, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_single, k_scale_single, write_idx_single, buffer_single) - k_fp8_ext_single, 
k_scale_ext_single = extract_indexer_ks(buffer_single, write_idx_single) - - fp8_match_single = torch.allclose( - k_fp8_ext_single.to(torch.float32), - k_fp8_single.to(torch.float32), - atol=0, rtol=0 - ) - scale_match_single = torch.allclose( - k_scale_ext_single, - k_scale_single.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - - print(f" Single element: FP8={fp8_match_single}, Scale={scale_match_single}") - assert fp8_match_single and scale_match_single - print("✓ Single element test passed!") - print() - - # Test with larger batch to check performance characteristics - print("Testing larger batch (performance check)...") - write_indices_large = torch.randint(0, large_size * 10, (500,), device="cuda", dtype=torch.int32).unique() - actual_large_len = len(write_indices_large) - k_bf16_large = torch.randn((actual_large_len, head_dim), dtype=dtype, device="cuda") - k_abs_max_large = k_bf16_large.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-12) - k_scale_large = (k_abs_max_large / torch.finfo(fp8_type).max).to(torch.float32) - k_fp8_large = (k_bf16_large / k_abs_max_large).clamp( - torch.finfo(fp8_type).min, torch.finfo(fp8_type).max - ).to(fp8_type) - - buffer_large = torch.zeros((large_size * 10, 1, 132), dtype=torch.uint8, device="cuda") - destindex_copy_indexer_ks(k_fp8_large, k_scale_large, write_indices_large, buffer_large) - - # Warm up - for _ in range(3): - _ = extract_indexer_ks(buffer_large, write_indices_large) - - # Time it - torch.cuda.synchronize() - import time - start = time.time() - for _ in range(100): - k_fp8_ext_large, k_scale_ext_large = extract_indexer_ks(buffer_large, write_indices_large) - torch.cuda.synchronize() - elapsed = time.time() - start - - fp8_match_large = torch.allclose( - k_fp8_ext_large.to(torch.float32), - k_fp8_large.to(torch.float32), - atol=0, rtol=0 - ) - scale_match_large = torch.allclose( - k_scale_ext_large, - k_scale_large.squeeze(-1), - atol=1e-6, rtol=1e-5 - ) - - print(f" Large batch (size={actual_large_len}): 
FP8={fp8_match_large}, Scale={scale_match_large}") - print(f" Average time per call: {elapsed/100*1000:.3f} ms") - assert fp8_match_large and scale_match_large - print("✓ Large batch test passed!") - print() - - print("=" * 80) - print("All tests passed successfully! ✓") - print("=" * 80) - -if __name__ == "__main__": - test_extract_indexer_ks() + return O_fp8, O_scale diff --git a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py index e8f1bbfa21..1c1f72b7d7 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py +++ b/lightllm/models/deepseek3_2/triton_kernel/fp8_mqa_logits.py @@ -1,4 +1,3 @@ -# import triton import triton.language as tl import torch @@ -35,68 +34,44 @@ def _fp8_paged_mqa_logits_kernel( pid_m = tl.program_id(0) pid_n = tl.program_id(1) - # Compute the range of seq positions this block handles start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N - # Offset arrays for this block offs_m = start_m + tl.arange(0, BLOCK_SIZE_M) offs_n = start_n + tl.arange(0, BLOCK_SIZE_N) - # Initialize accumulator for logits logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Create masks mask_m = offs_m < seq_len mask_n = offs_n < seq_len_kv - # Load mem_indices for the KV positions mem_indices = tl.load(MemIndex_ptr + offs_n, mask=mask_n, other=0) - # Load scales for K scales = tl.load(KVScale_ptr + mem_indices, mask=mask_n, other=1.0) - # Loop over all heads for h in range(num_heads): - # Load weights for this head weights = tl.load(Weights_ptr + offs_m * stride_w_seq + h * stride_w_head, mask=mask_m, other=0.0) - - # Initialize score accumulator for this head score = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Loop over head_dim in blocks for d_block in range(tl.cdiv(head_dim, BLOCK_SIZE_D)): d_start = d_block * BLOCK_SIZE_D offs_d = d_start + tl.arange(0, BLOCK_SIZE_D) mask_d = offs_d < head_dim - # Load Q for this head and 
dimension block - # Q shape: (seq_len, num_heads, head_dim) q_ptrs = Q_ptr + offs_m[:, None] * stride_q_seq + h * stride_q_head + offs_d[None, :] * stride_q_dim mask_q = (offs_m[:, None] < seq_len) & mask_d[None, :] q = tl.load(q_ptrs, mask=mask_q, other=0.0).to(tl.float32) - # Load K for this dimension block - # KV shape: (pool_size, head_dim) as FP8 data k_ptrs = KV_ptr + mem_indices[:, None] * stride_kv_pool + offs_d[None, :] * stride_kv_dim mask_k = mask_n[:, None] & mask_d[None, :] k = tl.load(k_ptrs, mask=mask_k, other=0.0).to(tl.float32) - # Apply scale to K (scale is per-row of K) k = k * scales[:, None] - # Compute partial dot product: q @ k.T - # q: (BLOCK_SIZE_M, BLOCK_SIZE_D), k: (BLOCK_SIZE_N, BLOCK_SIZE_D) - # score: (BLOCK_SIZE_M, BLOCK_SIZE_N) score += tl.dot(q, tl.trans(k)) - - # Apply ReLU to score score = tl.maximum(score, 0.0) - - # Multiply by weights and accumulate to logits logits += score * weights[:, None] - # Apply mask based on cu_seqlen_ks and cu_seqlen_ke mask_ks = tl.load(CuSeqlenKs_ptr + offs_m, mask=mask_m, other=0) mask_ke = tl.load(CuSeqlenKe_ptr + offs_m, mask=mask_m, other=seq_len_kv) @@ -104,7 +79,6 @@ def _fp8_paged_mqa_logits_kernel( mask_hi = offs_n[None, :] < mask_ke[:, None] mask_valid = mask_lo & mask_hi & mask_m[:, None] & mask_n[None, :] - # Apply mask (-inf for masked positions) logits = tl.where(mask_valid, logits, float("-inf")) # Store output diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 4b92298b6b..121c272973 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -92,9 +92,12 @@ def make_argument_parser() -> argparse.ArgumentParser: "--tokenizer_mode", type=str, default="fast", - help="""tokenizer load mode, can be slow, fast or auto, slow mode load fast but run slow, - slow mode is good for debug and test, fast mode get best performance, auto mode will - try to use fast mode, if failed will use slow mode""", + help="""tokenizer load mode, can be slow, fast, auto, or 
deepseek_v32. + slow mode load fast but run slow, good for debug and test. + fast mode get best performance. + auto mode will try to use fast mode, if failed will use slow mode. + deepseek_v32 mode wraps the tokenizer with Python-based DSML chat + template encoding for DeepSeek-V3.2 models (no --chat_template needed).""", ) parser.add_argument( "--load_way", diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index b65f7e9cc5..11e24612b0 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -464,39 +464,6 @@ async def stream_results() -> AsyncGenerator[bytes, None]: yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") # Additional usage chunk - # Finalize any pending tool calls (e.g., DSML format last invoke) - if request.tool_choice != "none" and request.tools and parser_dict: - for _idx, _parser in parser_dict.items(): - _, finalize_calls = _parser.finalize_stream() - history_tool_calls_cnt = _get_history_tool_calls_cnt(request) - for call_item in finalize_calls: - if call_item.name: - tool_call_id = _process_tool_call_id(tool_parser, call_item, history_tool_calls_cnt) - function_name = call_item.name - else: - tool_call_id = None - function_name = None - tool_call = ToolCall( - id=tool_call_id, - index=getattr(call_item, "tool_index", None), - function=FunctionResponse( - name=function_name, - arguments=call_item.parameters, - ), - ) - choice_data = ChatCompletionStreamResponseChoice( - index=0, - delta=DeltaMessage(role="assistant", tool_calls=[tool_call]), - finish_reason="tool_calls", - ) - chunk = ChatCompletionStreamResponse( - id=group_request_id, - created=created_time, - choices=[choice_data], - model=request.model, - ) - yield f"data: {chunk.model_dump_json()}\n\n" - if request.stream_options and request.stream_options.include_usage: usage = UsageInfo( prompt_tokens=prompt_tokens, diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py 
index d77184863f..f770459a55 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -1,28 +1,11 @@ -import json -import os - tokenizer = None -_model_type = None def init_tokenizer(args): - global tokenizer, _model_type + global tokenizer from lightllm.server.tokenizer import get_tokenizer tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) - - # Detect model type for specialized encoding (e.g. DeepSeek-V3.2) - config_path = os.path.join(args.model_dir, "config.json") - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - model_config = json.load(f) - _model_type = model_config.get("model_type", None) - # Check architectures as fallback - if _model_type is None: - archs = model_config.get("architectures", []) - if any("DeepseekV32" in a for a in archs): - _model_type = "deepseek_v32" - chat_path = args.chat_template if chat_path is not None: with open(chat_path, "r", encoding="utf-8") as f: @@ -31,14 +14,9 @@ def init_tokenizer(args): async def build_prompt(request, tools) -> str: - global tokenizer, _model_type + global tokenizer # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] - - # Use DeepSeek-V3.2 native encoding when applicable - if _model_type == "deepseek_v32": - return _build_prompt_dsv32(messages, tools, request) - kwargs = {"conversation": messages} if request.character_settings: kwargs["character_settings"] = request.character_settings @@ -62,27 +40,3 @@ async def build_prompt(request, tools) -> str: tools=tools, ) return input_str - - -def _build_prompt_dsv32(messages, tools, request): - from lightllm.server.encoding_dsv32 import encode_messages - - # Inject tools into system message if present - if tools is not None and len(tools) > 0: - wrapped_tools = [t if "function" in t else {"function": t} for t in tools] - if messages and 
messages[0].get("role") == "system": - messages[0]["tools"] = wrapped_tools - else: - messages.insert(0, {"role": "system", "tools": wrapped_tools}) - - # Determine thinking mode from request - thinking = False - if request.chat_template_kwargs: - thinking = request.chat_template_kwargs.get("thinking", False) or request.chat_template_kwargs.get( - "enable_thinking", False - ) - - thinking_mode = "thinking" if thinking else "chat" - drop_thinking = messages[-1]["role"] == "user" if messages else True - - return encode_messages(messages, thinking_mode=thinking_mode, drop_thinking=drop_thinking) diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index f2c5177e41..b9ad314dd9 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -332,34 +332,27 @@ class SamplingParams(ctypes.Structure): _top_p: float = 1.0 _top_k: int = -1 # -1 is for all - @staticmethod - def _get(kwargs, key, default): - """Get value from kwargs, falling back to default when value is None or missing.""" - val = kwargs.get(key) - return val if val is not None else default - def init(self, tokenizer, **kwargs): super().__init__() - _get = SamplingParams._get - self.best_of = _get(kwargs, "best_of", 1) - self.n = _get(kwargs, "n", self.best_of) - self.do_sample = _get(kwargs, "do_sample", SamplingParams._do_sample) - self.presence_penalty = _get(kwargs, "presence_penalty", SamplingParams._presence_penalty) - self.frequency_penalty = _get(kwargs, "frequency_penalty", SamplingParams._frequency_penalty) - self.repetition_penalty = _get(kwargs, "repetition_penalty", SamplingParams._repetition_penalty) - self.temperature = _get(kwargs, "temperature", SamplingParams._temperature) - self.top_p = _get(kwargs, "top_p", SamplingParams._top_p) - self.top_k = _get(kwargs, "top_k", SamplingParams._top_k) - self.ignore_eos = _get(kwargs, "ignore_eos", False) - self.image_max_patch_num = _get(kwargs, 
"image_max_patch_num", -1) - self.max_new_tokens = _get(kwargs, "max_new_tokens", 16) - self.min_new_tokens = _get(kwargs, "min_new_tokens", 1) - self.input_penalty = _get(kwargs, "input_penalty", DEFAULT_INPUT_PENALTY) - self.group_request_id = _get(kwargs, "group_request_id", -1) - self.suggested_dp_index = _get(kwargs, "suggested_dp_index", -1) - - self.skip_special_tokens = _get(kwargs, "skip_special_tokens", SKIP_SPECIAL_TOKENS) - self.disable_prompt_cache = _get(kwargs, "disable_prompt_cache", False) + self.best_of = kwargs.get("best_of", 1) + self.n = kwargs.get("n", self.best_of) + self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) + self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) + self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) + self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) + self.temperature = kwargs.get("temperature", SamplingParams._temperature) + self.top_p = kwargs.get("top_p", SamplingParams._top_p) + self.top_k = kwargs.get("top_k", SamplingParams._top_k) + self.ignore_eos = kwargs.get("ignore_eos", False) + self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) + self.max_new_tokens = kwargs.get("max_new_tokens", 16) + self.min_new_tokens = kwargs.get("min_new_tokens", 1) + self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) + self.group_request_id = kwargs.get("group_request_id", -1) + self.suggested_dp_index = kwargs.get("suggested_dp_index", -1) + + self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) + self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False) self.add_special_tokens = kwargs.get("add_special_tokens", True) self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 
c3faf21e78..3a8fddf744 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -1453,27 +1453,26 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami class DeepSeekV32Detector(BaseFormatDetector): """ - Detector for DeepSeek V3.2 model function call format (DSML). - - DeepSeek V3.2 uses a new DSML (DeepSeek Markup Language) format for tool calls, - which is XML-like rather than JSON-based. + Detector for DeepSeek V3.2 model function call format using DSML + (DeepSeek Markup Language). Format Structure: ``` <|DSML|function_calls> <|DSML|invoke name="get_weather"> - <|DSML|parameter name="location" string="true">杭州 - <|DSML|parameter name="date" string="true">2024-01-16 - <|DSML|invoke name="get_weather"> - <|DSML|parameter name="location" string="true">北京 - <|DSML|parameter name="date" string="true">2024-01-16 + <|DSML|parameter name="location" string="true">Hangzhou + <|DSML|parameter name="date" string="true">2024-01-16 + + ``` Key Components: - - Tool Calls Section: Starts with `<|DSML|function_calls>` - - Individual Invoke: `<|DSML|invoke name="function_name">` - - Parameters: `<|DSML|parameter name="param_name" string="true">value` - - Parameter types are inferred from the tool schema for proper JSON serialization + - Function Calls Block: `<|DSML|function_calls>` ... `` + - Individual Invocation: `<|DSML|invoke name="func">` ... 
`` + - Parameters: `<|DSML|parameter name="key" string="true|false">value` + - string="true": value is plain text (will be JSON-escaped) + - string="false": value is JSON (numbers, booleans, arrays, objects) + - Supports multiple parallel tool calls Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3.2 """ @@ -1481,333 +1480,132 @@ class DeepSeekV32Detector(BaseFormatDetector): def __init__(self): super().__init__() self.dsml_token = "|DSML|" - self.bot_token = "<|DSML|function_calls>" - self.eot_token = "" # DSML format has no explicit end token - self.invoke_prefix = '<|DSML|invoke name="' - self.parameter_prefix = '<|DSML|parameter name="' - - # Regex for complete parsing + self.bot_token = f"<{self.dsml_token}function_calls>" + self.eot_token = f"" + self.invoke_start_prefix = f"<{self.dsml_token}invoke" + self.invoke_end_token = f"" + self.param_end_token = f"" + + # Regex for complete invoke extraction + _de = re.escape(self.dsml_token) self.invoke_regex = re.compile( - r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)(?=<|DSML|invoke|$)', + rf'<{_de}invoke\s+name="([^"]+)"\s*>(.*?)', re.DOTALL, ) - # Captures: (param_name, is_string, value) - self.parameter_regex = re.compile( - r'<|DSML|parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>(.*?)(?=<|DSML|parameter|<|DSML|invoke|$)', + # Regex for parameter extraction + self.param_regex = re.compile( + rf'<{_de}parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>(.*?)', + re.DOTALL, + ) + # Regex for partial invoke (name known, body still streaming) + self.partial_invoke_regex = re.compile( + rf'<{_de}invoke\s+name="([^"]+)"\s*>(.*)', re.DOTALL, ) - # Streaming state self._last_arguments = "" - self._current_invoke_text = "" - self._invoke_count = 0 - self._param_count_in_invoke = 0 - self._accumulated_params: Dict[str, str] = {} - self._json_started = False - self._tools_schema: Optional[Dict[str, Dict]] = None - self._tool_indices: Optional[Dict[str, int]] = None - self._current_func_name: Optional[str] = 
None - self._in_tool_call_sequence = False # Set True once bot_token seen + self._accumulated_params: List[tuple] = [] + self._in_function_calls = False # Track if we're inside a function_calls block def has_tool_call(self, text: str) -> bool: - """Check if the text contains a DeepSeek V3.2 DSML format tool call.""" return self.bot_token in text - def _get_param_type(self, func_name: str, param_name: str, tools: List[Tool]) -> str: - """Get the JSON Schema type of a parameter from the tool definition.""" - if self._tools_schema is None: - self._tools_schema = {} - for tool in tools: - if tool.function.name and tool.function.parameters: - props = tool.function.parameters.get("properties", {}) - self._tools_schema[tool.function.name] = props - - func_schema = self._tools_schema.get(func_name, {}) - param_schema = func_schema.get(param_name, {}) - return param_schema.get("type", "string") - - def _convert_param_value(self, value: str, is_string_attr: str, param_type: str) -> Any: - """Convert a raw parameter value string to the appropriate Python type. - - Args: - value: The raw string value from the DSML parameter tag. - is_string_attr: The "string" attribute from DSML ("true" or "false"). - If "true", the value is treated as a raw string. - If "false", the value is parsed based on param_type or JSON. - param_type: The JSON Schema type from the tool definition (fallback). - """ - value = value.strip() - if value.lower() == "null": - return None - - # Use DSML string attribute as primary signal - if is_string_attr == "true": - return value - - # string="false" - parse based on schema type or attempt JSON - param_type = param_type.lower() - if param_type in ("integer", "int"): - try: - return int(value) - except (ValueError, TypeError): - return value - elif param_type in ("number", "float"): - try: - val = float(value) - # Only coerce to int if it's actually an integer string - if "." 
not in value and "e" not in value.lower(): - return int(value) - return val - except (ValueError, TypeError, OverflowError): - return value - elif param_type in ("boolean", "bool"): - lower = value.lower() - if lower in ("true", "1"): - return True - elif lower in ("false", "0"): - return False + def _dsml_params_to_json(self, params: List[tuple]) -> str: + """Convert DSML parameter tuples (name, is_str, value) to a JSON arguments string.""" + args = {} + for name, is_str, value in params: + if is_str == "true": + args[name] = value else: - logger.warning(f"Unexpected boolean value: {value!r}, treating as string") - return value - elif param_type in ("object", "array"): - try: - return json.loads(value) - except json.JSONDecodeError: - return value - else: - # Unknown type with string="false" - try JSON parse, fallback to string - try: - return json.loads(value) - except json.JSONDecodeError: - return value - - def _parse_invoke_params(self, invoke_content: str, func_name: str, tools: List[Tool]) -> Dict: - """Parse all parameters from an invoke block content.""" - params = {} - for param_name, is_string_attr, param_value in self.parameter_regex.findall(invoke_content): - param_name = param_name.strip() - param_value = param_value.strip() - param_type = self._get_param_type(func_name, param_name, tools) - params[param_name] = self._convert_param_value(param_value, is_string_attr, param_type) - return params + try: + args[name] = json.loads(value) + except (json.JSONDecodeError, ValueError): + args[name] = value + return json.dumps(args, ensure_ascii=False) def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: - """ - One-time parsing: Detects and parses DSML tool calls in the provided text. 
- """ - if self.bot_token not in text: - return StreamingParseResult(normal_text=text, calls=[]) - + """One-time parsing for DSML format tool calls.""" idx = text.find(self.bot_token) - normal_text = text[:idx].strip() if idx > 0 else "" - tool_section = text[idx:] + normal_text = text[:idx].strip() if idx != -1 else text + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) tool_indices = self._get_tool_indices(tools) calls = [] - try: - for func_name, invoke_content in self.invoke_regex.findall(tool_section): - func_name = func_name.strip() - if func_name not in tool_indices: - logger.warning(f"Model attempted to call undefined function: {func_name}") - continue + invoke_matches = self.invoke_regex.findall(text) + for func_name, invoke_body in invoke_matches: + if func_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {func_name}") + continue - params = self._parse_invoke_params(invoke_content, func_name, tools) - calls.append( - ToolCallItem( - tool_index=tool_indices[func_name], - name=func_name, - parameters=json.dumps(params, ensure_ascii=False), - ) - ) - return StreamingParseResult(normal_text=normal_text, calls=calls) - except Exception as e: - logger.error(f"Error in DeepSeekV32 detect_and_parse: {e}") - return StreamingParseResult(normal_text=text) + param_matches = self.param_regex.findall(invoke_body) + args_json = self._dsml_params_to_json(param_matches) - def finalize_streaming(self, tools: List[Tool]) -> StreamingParseResult: - """Finalize the last pending tool call when generation ends (EOS). + calls.append( + ToolCallItem( + tool_index=tool_indices[func_name], + name=func_name, + parameters=args_json, + ) + ) - The DSML format has no explicit end token, so the last invoke's last - parameter may remain unconfirmed. This method should be called when - the stream ends to close any open JSON and emit remaining parameters. 
- """ - if not self.current_tool_name_sent or self.current_tool_id < 0: - return StreamingParseResult() + return StreamingParseResult(normal_text=normal_text, calls=calls) - calls: List[ToolCallItem] = [] + def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult: + """Streaming incremental parsing for DSML format tool calls.""" + self._buffer += new_text current_text = self._buffer - try: - # Find current invoke text - invoke_positions = [] - search_start = 0 - while True: - pos = current_text.find(self.invoke_prefix, search_start) - if pos == -1: - break - invoke_positions.append(pos) - search_start = pos + len(self.invoke_prefix) - - if self._invoke_count < len(invoke_positions): - invoke_start = invoke_positions[self._invoke_count] - invoke_text = current_text[invoke_start:] - - name_content_start = len(self.invoke_prefix) - name_end = invoke_text.find('">', name_content_start) - if name_end != -1: - func_name = invoke_text[name_content_start:name_end].strip() - invoke_body = invoke_text[name_end + 2 :] - - # Parse all remaining params (including the last unconfirmed one) - param_matches = list(self.parameter_regex.finditer(invoke_body)) - for i in range(self._param_count_in_invoke, len(param_matches)): - match = param_matches[i] - param_name = match.group(1).strip() - is_string_attr = match.group(2) - param_value = match.group(3).strip() - - param_type = self._get_param_type(func_name, param_name, tools) - converted_value = self._convert_param_value(param_value, is_string_attr, param_type) - serialized_value = json.dumps(converted_value, ensure_ascii=False) - - if not self._json_started: - json_fragment = "{" + f'"{param_name}": {serialized_value}' - self._json_started = True - else: - json_fragment = f', "{param_name}": {serialized_value}' + # Check if we're inside a function_calls block or starting one + has_tool = self.has_tool_call(current_text) or self._in_function_calls - self._accumulated_params[param_name] = 
converted_value - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters=json_fragment, - ) - ) - self.streamed_args_for_tool[self.current_tool_id] += json_fragment + if not has_tool: + partial_len = self._ends_with_partial_token(current_text, self.bot_token) + if partial_len: + return StreamingParseResult() - # Close the JSON object - if self._json_started: - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters="}", - ) - ) - self.streamed_args_for_tool[self.current_tool_id] += "}" - elif self.current_tool_name_sent: - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters="{}", - ) - ) - self.streamed_args_for_tool[self.current_tool_id] = "{}" - - # Update prev_tool_call_arr - if self.current_tool_id < len(self.prev_tool_call_arr): - self.prev_tool_call_arr[self.current_tool_id]["arguments"] = self._accumulated_params - - # Reset state - self._invoke_count += 1 - self.current_tool_id += 1 - self.current_tool_name_sent = False - self._json_started = False - self._accumulated_params = {} self._buffer = "" + for e_token in [self.eot_token, self.invoke_end_token]: + if e_token in new_text: + new_text = new_text.replace(e_token, "") + return StreamingParseResult(normal_text=new_text) - return StreamingParseResult(normal_text="", calls=calls) - except Exception as e: - logger.error(f"Error in DeepSeekV32 finalize_streaming: {e}") - return StreamingParseResult(normal_text="", calls=calls) - - def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult: - """ - Streaming incremental parsing for DeepSeek V3.2 DSML tool calls. + # Mark that we're inside a function_calls block + if self.has_tool_call(current_text): + self._in_function_calls = True - The DSML format streams line-by-line with invoke/parameter tokens. 
- We accumulate parameters and only emit JSON fragments when a parameter's - value is confirmed complete (by seeing the next parameter/invoke boundary). - """ - self._buffer += new_text - current_text = self._buffer + # Check if function_calls block has ended + if self.eot_token in current_text: + self._in_function_calls = False - # Check if we have any DSML content - if not self._in_tool_call_sequence: - if not self.has_tool_call(current_text): - # Check for partial start token - if self._ends_with_partial_token(current_text, self.bot_token): - return StreamingParseResult() - self._buffer = "" - return StreamingParseResult(normal_text=new_text) - self._in_tool_call_sequence = True - - if self._tool_indices is None: + if not hasattr(self, "_tool_indices"): self._tool_indices = self._get_tool_indices(tools) calls: List[ToolCallItem] = [] try: - # Find all invoke starts in current buffer - invoke_positions = [] - search_start = 0 - while True: - pos = current_text.find(self.invoke_prefix, search_start) - if pos == -1: - break - invoke_positions.append(pos) - search_start = pos + len(self.invoke_prefix) - - if not invoke_positions: - # Have bot_token but no invoke yet - keep buffering - return StreamingParseResult() - - # Process only the current (latest) invoke block - current_invoke_idx = self._invoke_count - if current_invoke_idx >= len(invoke_positions): - # All invokes already processed, keep buffering for new ones - return StreamingParseResult() - - invoke_start = invoke_positions[current_invoke_idx] - # Whether the current invoke is bounded by a next invoke - invoke_is_bounded = current_invoke_idx + 1 < len(invoke_positions) - if invoke_is_bounded: - invoke_end = invoke_positions[current_invoke_idx + 1] - else: - invoke_end = len(current_text) - - invoke_text = current_text[invoke_start:invoke_end] + # Try to find complete invoke blocks first + complete_invoke_match = self.invoke_regex.search(current_text) + if complete_invoke_match: + func_name = 
complete_invoke_match.group(1) + invoke_body = complete_invoke_match.group(2) - # Extract function name - name_start = invoke_text.find(self.invoke_prefix) - if name_start == -1: - return StreamingParseResult() - - name_content_start = name_start + len(self.invoke_prefix) - name_end = invoke_text.find('">', name_content_start) - if name_end == -1: - # Function name not complete yet - return StreamingParseResult() - - func_name = invoke_text[name_content_start:name_end].strip() + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._accumulated_params = [] - # Initialize state for this tool call - if self.current_tool_id == -1: - self.current_tool_id = 0 - self.prev_tool_call_arr = [] - self.streamed_args_for_tool = [""] + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") - while len(self.prev_tool_call_arr) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - while len(self.streamed_args_for_tool) <= self.current_tool_id: - self.streamed_args_for_tool.append("") + param_matches = self.param_regex.findall(invoke_body) + args_json = self._dsml_params_to_json(param_matches) - # Send tool name if not sent yet - if not self.current_tool_name_sent: - if func_name and func_name in self._tool_indices: + if not self.current_tool_name_sent: calls.append( ToolCallItem( tool_index=self.current_tool_id, @@ -1816,109 +1614,101 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami ) ) self.current_tool_name_sent = True - self.prev_tool_call_arr[self.current_tool_id] = { - "name": func_name, - "arguments": {}, - } - self._current_func_name = func_name - self._accumulated_params = {} - self._param_count_in_invoke = 0 - self._json_started = False - return StreamingParseResult(calls=calls) - return 
StreamingParseResult() - - # Parse parameters from the invoke block content - invoke_body = invoke_text[name_end + 2 :] # after '">' - - # Find all parameter starts within this invoke body - param_positions = [] - ps = 0 - while True: - pp = invoke_body.find(self.parameter_prefix, ps) - if pp == -1: - break - param_positions.append(pp) - ps = pp + len(self.parameter_prefix) - - # A parameter is "confirmed" when the next parameter/invoke boundary is visible, - # meaning the parameter's value won't grow further. - # For the last parameter in the invoke body, it's only confirmed if - # the invoke itself is bounded by a next invoke. - confirmed_count = 0 - for pi in range(len(param_positions)): - if pi + 1 < len(param_positions): - confirmed_count += 1 - elif invoke_is_bounded: - confirmed_count += 1 - - # Only emit newly confirmed parameters - if confirmed_count > self._param_count_in_invoke: - param_matches = list(self.parameter_regex.finditer(invoke_body)) - for i in range(self._param_count_in_invoke, min(confirmed_count, len(param_matches))): - match = param_matches[i] - param_name = match.group(1).strip() - is_string_attr = match.group(2) - param_value = match.group(3).strip() - - param_type = self._get_param_type(func_name, param_name, tools) - converted_value = self._convert_param_value(param_value, is_string_attr, param_type) - serialized_value = json.dumps(converted_value, ensure_ascii=False) - - if not self._json_started: - json_fragment = "{" + f'"{param_name}": {serialized_value}' - self._json_started = True - else: - json_fragment = f', "{param_name}": {serialized_value}' - - self._accumulated_params[param_name] = converted_value + # Send complete arguments (or remaining diff) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = args_json[sent:] + if argument_diff: calls.append( ToolCallItem( tool_index=self.current_tool_id, name=None, - parameters=json_fragment, + parameters=argument_diff, ) ) - 
self.streamed_args_for_tool[self.current_tool_id] += json_fragment + self.streamed_args_for_tool[self.current_tool_id] += argument_diff - self._param_count_in_invoke = confirmed_count + try: + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": json.loads(args_json), + } + except json.JSONDecodeError: + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } - # Check if next invoke has started (meaning current one is complete) - if invoke_is_bounded: - # Current invoke is complete, close JSON and advance - if self._json_started: - close_fragment = "}" - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters=close_fragment, - ) - ) - self.streamed_args_for_tool[self.current_tool_id] += close_fragment + # Remove processed invoke from buffer + invoke_end_pos = current_text.find(self.invoke_end_token, complete_invoke_match.start()) + if invoke_end_pos != -1: + self._buffer = current_text[invoke_end_pos + len(self.invoke_end_token) :] else: - calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters="{}", - ) - ) - self.streamed_args_for_tool[self.current_tool_id] = "{}" + self._buffer = current_text[complete_invoke_match.end() :] - # Update prev_tool_call_arr - self.prev_tool_call_arr[self.current_tool_id]["arguments"] = self._accumulated_params - - # Advance to next invoke, prune consumed buffer content - # Reset _invoke_count to 0 since buffer positions are now relative - self._buffer = current_text[invoke_end:] - self._invoke_count = 0 self.current_tool_id += 1 - self.current_tool_name_sent = False self._last_arguments = "" - self._accumulated_params = {} - self._param_count_in_invoke = 0 - self._json_started = False + self.current_tool_name_sent = False + self._accumulated_params = [] + self.streamed_args_for_tool.append("") + + return StreamingParseResult(normal_text="", calls=calls) + + # Partial invoke: name is known but 
parameters are still streaming + partial_match = self.partial_invoke_regex.search(current_text) + if partial_match: + func_name = partial_match.group(1) + partial_body = partial_match.group(2) + + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._accumulated_params = [] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + if not self.current_tool_name_sent: + if func_name in self._tool_indices: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + else: + # Stream arguments as complete parameters are parsed + param_matches = self.param_regex.findall(partial_body) + if param_matches and len(param_matches) > len(self._accumulated_params): + self._accumulated_params = param_matches + current_args_json = self._dsml_params_to_json(param_matches) + + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = current_args_json[sent:] + + if argument_diff: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=argument_diff, + ) + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff + + try: + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = json.loads(current_args_json) + except json.JSONDecodeError: + pass return StreamingParseResult(normal_text="", calls=calls) @@ -2020,19 +1810,3 @@ def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]: final_normal_text = sp_result.normal_text return final_normal_text, final_calls - - def finalize_stream(self) -> Tuple[str, list[ToolCallItem]]: - """Finalize streaming when generation ends. 
- - For detectors that lack an explicit end-of-tool-call token (like DSML), - this closes any pending tool call JSON. For other detectors, this is a no-op. - - Returns: - A tuple of (normal_text, calls) like parse_stream_chunk. - """ - if not self.tools: - return "", [] - if hasattr(self.detector, "finalize_streaming"): - sp_result = self.detector.finalize_streaming(self.tools) - return sp_result.normal_text, sp_result.calls - return "", [] diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 09bc938f23..4b4a0d830a 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -45,6 +45,17 @@ def get_tokenizer( **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: """Gets a tokenizer for the given model name via Huggingface.""" + # DeepSeek-V3.2 custom tokenizer mode: wraps the HF tokenizer with + # a Python-based apply_chat_template that uses encoding_dsv32.py. + if tokenizer_mode == "deepseek_v32": + from ..models.deepseek3_2.model import DeepSeekV32Tokenizer + + hf_tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs + ) + logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") + return DeepSeekV32Tokenizer(hf_tokenizer) + if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") From 2ca681cbfe5901e518c0f4759cd34fd1e52b6ab9 Mon Sep 17 00:00:00 2001 From: Developer Date: Wed, 4 Feb 2026 13:50:50 +0000 Subject: [PATCH 20/58] fix --- .../layer_infer/transformer_layer_infer.py | 5 --- .../layer_weights/nsa_indexer_layer_weight.py | 2 - lightllm/models/deepseek3_2/mem_manager.py | 4 -- lightllm/models/deepseek3_2/model.py | 45 +++++++++++++------ 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py 
b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 9dba923cc1..13a0c1394f 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,5 +1,3 @@ -from typing import override - import torch from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer @@ -33,7 +31,6 @@ def _get_nsa_backend(self): self._nsa_backend = self._nsa_backend_class(model=None) return self._nsa_backend - @override def _get_qkv( self, input: torch.Tensor, @@ -68,7 +65,6 @@ def _get_qkv( ) return q, cache_kv - @override def _context_attention_kernel( self, q: torch.Tensor, @@ -104,7 +100,6 @@ def _context_attention_kernel( ) return mla_out - @override def _token_attention_kernel( self, q, diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py index 6df1a88215..023b89979b 100644 --- a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py @@ -11,14 +11,12 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg): super().__init__(layer_num, data_type, network_config, quant_cfg) return - @override def _parse_config(self): self.q_lora_rank = self.network_config_["q_lora_rank"] self.index_n_heads = self.network_config_["index_n_heads"] self.index_head_dim = self.network_config_["index_head_dim"] self.hidden_size = self.network_config_["hidden_size"] - @override def _init_weight(self): prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/models/deepseek3_2/mem_manager.py index dc78f1de4c..fdb2e87c6b 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/models/deepseek3_2/mem_manager.py @@ -1,4 +1,3 @@ -from typing_extensions import override import 
torch from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager @@ -15,16 +14,13 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, layer_num) - @override def get_cell_size(self): return super().get_cell_size() + 132 - @override def _free_buffers(self): super()._free_buffers() self.indexer_ks_buffer = None - @override def resize_mem(self, new_size): super().resize_mem(new_size) self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, self.layer_num) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index f907b0bed6..77804096b1 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -1,17 +1,26 @@ import copy import json import logging +import os from lightllm.models.registry import ModelRegistry from lightllm.models.deepseek2.model import Deepseek2TpPartModel -from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight -from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager _logger = logging.getLogger(__name__) +# When ENABLE_NSA is set, use the full V32 NSA (Native Sparse Attention) pipeline +# including the indexer, custom memory manager, and NSA-aware attention kernels. +# When not set, fall back to the DeepSeek V3 (Deepseek2) inference path while +# keeping V32-specific tokenizer/parser support intact. 
+_ENABLE_NSA = os.environ.get("ENABLE_NSA", "0").lower() in ("1", "true") + +if _ENABLE_NSA: + from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight + from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer + from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo + from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager + class DeepSeekV32Tokenizer: """Tokenizer wrapper for DeepSeek-V3.2 that uses the Python-based @@ -105,24 +114,32 @@ def apply_chat_template( @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): - # weight class - transformer_weight_class = Deepseek3_2TransformerLayerWeight - - # infer class - transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer - - # infer state class - infer_state_class = Deepseek3_2InferStateInfo + # When ENABLE_NSA is set, override with V32-specific NSA classes. + # Otherwise, inherit the V3/V2 classes from Deepseek2TpPartModel. + if _ENABLE_NSA: + transformer_weight_class = Deepseek3_2TransformerLayerWeight + transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer + infer_state_class = Deepseek3_2InferStateInfo def __init__(self, kvargs): super().__init__(kvargs) - self.index_topk = self.config["index_topk"] + if _ENABLE_NSA: + self.index_topk = self.config["index_topk"] + else: + _logger.info("ENABLE_NSA is not set, using DeepSeek V3 inference path (no NSA indexer).") return def _init_inferstate_cls(self): - self.infer_state_class = Deepseek3_2InferStateInfo + if _ENABLE_NSA: + self.infer_state_class = Deepseek3_2InferStateInfo + else: + super()._init_inferstate_cls() def _init_mem_manager(self): + if not _ENABLE_NSA: + # Fall back to the standard V3/V2 memory manager (no indexer buffer). 
+ return super()._init_mem_manager() + manager_class = Deepseek3_2MemoryManager if get_env_start_args().llm_kv_type == "fp8kv": manager_class = Deepseek3_2FP8KVMemoryManager From f412e42c337a7f86b3b7e44e4bea6a64f40305ab Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 4 Feb 2026 14:33:57 +0000 Subject: [PATCH 21/58] deepseekv32 model_type condition --- lightllm/server/tokenizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 4b4a0d830a..bbbb052b85 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -44,10 +44,12 @@ def get_tokenizer( *args, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) + model_type = model_cfg.get("model_type", "") """Gets a tokenizer for the given model name via Huggingface.""" # DeepSeek-V3.2 custom tokenizer mode: wraps the HF tokenizer with # a Python-based apply_chat_template that uses encoding_dsv32.py. - if tokenizer_mode == "deepseek_v32": + if model_type == "deepseek_v32": from ..models.deepseek3_2.model import DeepSeekV32Tokenizer hf_tokenizer = AutoTokenizer.from_pretrained( @@ -87,8 +89,6 @@ def get_tokenizer( "slowdown. Consider using a fast tokenizer instead." 
) - model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) - model_type = model_cfg.get("model_type", "") if model_cfg["architectures"][0] == "TarsierForConditionalGeneration": from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor From d72f0859d6a02f3a1603cf5bc4e3328baec89897 Mon Sep 17 00:00:00 2001 From: Developer Date: Thu, 5 Feb 2026 09:14:26 +0000 Subject: [PATCH 22/58] fix v1 streaming --- lightllm/server/api_openai.py | 40 +++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 11e24612b0..a98094c369 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -339,6 +339,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: finish_reason = None + has_emitted_tool_calls = False from .req_id_generator import convert_sub_id_to_group_id prompt_tokens = 0 @@ -359,7 +360,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: if reasoning_text: choice_data = ChatCompletionStreamResponseChoice( index=0, - delta=DeltaMessage(reasoning_content=reasoning_text), + delta=DeltaMessage(role="assistant", reasoning_content=reasoning_text), finish_reason=None, ) chunk = ChatCompletionStreamResponse( @@ -387,8 +388,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]: if normal_text: choice_data = ChatCompletionStreamResponseChoice( index=0, - delta=DeltaMessage(content=normal_text), - finish_reason=finish_reason if finish_reason else None, + delta=DeltaMessage(role="assistant", content=normal_text), + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=group_request_id, @@ -401,6 +402,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: # 2) if we found calls, we output them as separate chunk(s) history_tool_calls_cnt = _get_history_tool_calls_cnt(request) for call_item in 
calls: + has_emitted_tool_calls = True # transform call_item -> FunctionResponse + ToolCall if finish_reason == "stop": latest_delta_len = 0 @@ -437,7 +439,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choice_data = ChatCompletionStreamResponseChoice( index=0, delta=DeltaMessage(role="assistant", tool_calls=[tool_call]), - finish_reason="tool_calls", + finish_reason=None, ) chunk = ChatCompletionStreamResponse( id=group_request_id, @@ -447,22 +449,34 @@ async def stream_results() -> AsyncGenerator[bytes, None]: ) yield f"data: {chunk.model_dump_json()}\n\n" else: - group_request_id = convert_sub_id_to_group_id(sub_req_id) - delta_message = DeltaMessage(role="assistant", content=delta) - if finish_status.is_finished(): - finish_reason = finish_status.get_finish_reason() - stream_choice = ChatCompletionStreamResponseChoice( - index=0, delta=delta_message, finish_reason=finish_reason - ) + stream_choice = ChatCompletionStreamResponseChoice(index=0, delta=delta_message, finish_reason=None) stream_resp = ChatCompletionStreamResponse( id=group_request_id, created=created_time, model=request.model, choices=[stream_choice], ) - yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") - # Additional usage chunk + yield f"data: {stream_resp.model_dump_json()}\n\n" + + # Determine final finish_reason: override to "tool_calls" if tool calls were emitted + if has_emitted_tool_calls and finish_reason == "stop": + finish_reason = "tool_calls" + + # Final empty chunk containing only finish_reason (and role) + if finish_reason is not None: + final_choice = ChatCompletionStreamResponseChoice( + index=0, + delta=DeltaMessage(), + finish_reason=finish_reason, + ) + final_chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + model=request.model, + choices=[final_choice], + ) + yield f"data: {final_chunk.model_dump_json()}\n\n" if request.stream_options and request.stream_options.include_usage: 
usage = UsageInfo( From dd1a067f8b6fa50ebf4a7900107b22fefde79294 Mon Sep 17 00:00:00 2001 From: Developer Date: Thu, 5 Feb 2026 09:19:11 +0000 Subject: [PATCH 23/58] exclude_none --- lightllm/server/api_openai.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index a98094c369..b06d8068ce 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -369,7 +369,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" if request.tool_choice != "none" and request.tools: if index not in parser_dict: @@ -397,7 +397,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" # 2) if we found calls, we output them as separate chunk(s) history_tool_calls_cnt = _get_history_tool_calls_cnt(request) @@ -447,7 +447,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" else: delta_message = DeltaMessage(role="assistant", content=delta) stream_choice = ChatCompletionStreamResponseChoice(index=0, delta=delta_message, finish_reason=None) @@ -457,7 +457,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, choices=[stream_choice], ) - yield f"data: {stream_resp.model_dump_json()}\n\n" + yield f"data: {stream_resp.model_dump_json(exclude_none=True)}\n\n" # Determine final finish_reason: override to "tool_calls" if tool calls were emitted if has_emitted_tool_calls and finish_reason == "stop": @@ -476,7 +476,7 @@ async def stream_results() -> 
AsyncGenerator[bytes, None]: model=request.model, choices=[final_choice], ) - yield f"data: {final_chunk.model_dump_json()}\n\n" + yield f"data: {final_chunk.model_dump_json(exclude_none=True)}\n\n" if request.stream_options and request.stream_options.include_usage: usage = UsageInfo( @@ -491,7 +491,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, usage=usage, ) - yield f"data: {usage_chunk.model_dump_json()}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) @@ -693,7 +693,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, usage=usage, ) - yield f"data: {usage_chunk.model_dump_json()}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) From f352854c891b88dc36f81f00b18ea7e133a25b0a Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 3 Mar 2026 08:30:18 +0000 Subject: [PATCH 24/58] fix deepseekv3.2 --- .../basemodel/attention/create_utils.py | 46 ++++++----- .../common/kv_cache_mem_manager/__init__.py | 2 + .../deepseek3_2mem_manager.py} | 6 -- .../common/kv_cache_mem_manager/mem_utils.py | 9 +++ .../{infer_struct.py => _del_infer_struct.py} | 0 lightllm/models/deepseek3_2/model.py | 79 ++++--------------- lightllm/utils/backend_validator.py | 43 ++++++++++ 7 files changed, 95 insertions(+), 90 deletions(-) rename lightllm/{models/deepseek3_2/mem_manager.py => common/kv_cache_mem_manager/deepseek3_2mem_manager.py} (74%) rename lightllm/models/deepseek3_2/{infer_struct.py => _del_infer_struct.py} (100%) diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index e3bf81daed..1fcde2a5ca 100644 --- 
a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -1,10 +1,8 @@ """Attention backend selection utilities.""" - -import os -import torch from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger from lightllm.utils.backend_validator import validate +from typing import Dict from .base_att import BaseAttBackend from .triton.fp import TritonAttBackend from .triton.int4kv import Int4kvTritonAttBackend @@ -56,14 +54,16 @@ def _auto_select_backend( - llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", "triton"] + llm_dtype: str, + kv_type_to_backend: Dict[str, Dict[str, BaseAttBackend]], + priority_list: list = ["fa3", "flashinfer", "triton"], ) -> type: """Auto-select the best available backend with validation. Priority: FA3 > FlashInfer > Triton Each backend is validated in a subprocess with ground truth checks. """ - backend_map = mla_data_type_to_backend if is_mla else data_type_to_backend + backend_map = kv_type_to_backend for backend_name in priority_list: if validate(backend_name): @@ -82,7 +82,7 @@ def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashi if backend_str != "auto": return data_type_to_backend[llm_dtype][backend_str] else: - return _auto_select_backend(llm_dtype, is_mla=False, priority_list=priority_list) + return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list) def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: @@ -92,7 +92,7 @@ def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashin if backend_str != "auto": return data_type_to_backend[llm_dtype][backend_str] else: - return _auto_select_backend(llm_dtype, is_mla=False, priority_list=priority_list) + return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, 
priority_list=priority_list) def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: @@ -102,7 +102,7 @@ def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "fl if backend_str != "auto": return mla_data_type_to_backend[llm_dtype][backend_str] else: - return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) + return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list) def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: @@ -112,20 +112,24 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla if backend_str != "auto": return mla_data_type_to_backend[llm_dtype][backend_str] else: - return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) + return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list) -def get_nsa_prefill_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: - llm_dtype = "None" - if backend_str not in nsa_data_type_to_backend[llm_dtype]: - logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") - backend_str = "flashmla_sparse" - return nsa_data_type_to_backend[llm_dtype][backend_str] +def get_nsa_prefill_att_backend_class(index=0, priority_list: list = ["flashmla_sparse"]) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_prefill_att_backend[index] + if backend_str != "auto": + return nsa_data_type_to_backend[llm_dtype][backend_str] + else: + return _auto_select_backend(llm_dtype, kv_type_to_backend=nsa_data_type_to_backend, priority_list=priority_list) -def get_nsa_decode_att_backend_class(backend_str: str = "flashmla_sparse") -> BaseAttBackend: - llm_dtype = "None" - if backend_str not in 
nsa_data_type_to_backend[llm_dtype]: - logger.warning(f"NSA backend '{backend_str}' not found, falling back to flashmla_sparse") - backend_str = "flashmla_sparse" - return nsa_data_type_to_backend[llm_dtype][backend_str] +def get_nsa_decode_att_backend_class(index=0, priority_list: list = ["flashmla_sparse"]) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_decode_att_backend[index] + if backend_str != "auto": + return nsa_data_type_to_backend[llm_dtype][backend_str] + else: + return _auto_select_backend(llm_dtype, kv_type_to_backend=nsa_data_type_to_backend, priority_list=priority_list) diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 7d516e6728..d41d1555a3 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -4,6 +4,7 @@ from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager +from .deepseek3_2mem_manager import Deepseek3_2MemoryManager __all__ = [ "MemoryManager", @@ -13,4 +14,5 @@ "PPLINT4KVMemoryManager", "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", + "Deepseek3_2MemoryManager", ] diff --git a/lightllm/models/deepseek3_2/mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py similarity index 74% rename from lightllm/models/deepseek3_2/mem_manager.py rename to lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py index fdb2e87c6b..679a07bde4 100644 --- a/lightllm/models/deepseek3_2/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py @@ -1,6 +1,5 @@ import torch -from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager @@ -24,8 +23,3 @@ def _free_buffers(self): def 
resize_mem(self, new_size): super().resize_mem(new_size) self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, self.layer_num) - - -class Deepseek3_2FP8KVMemoryManager(Deepseek3_2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction) diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 1ff58b89a0..b22590a6f2 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -5,6 +5,7 @@ PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, + Deepseek3_2MemoryManager, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args @@ -19,6 +20,14 @@ def select_mem_manager_class(): # case 1 # 先判断是否是 deepseek 系列的模型 model_class = get_llm_model_class() + + from lightllm.models import Deepseek3_2TpPartModel + + if issubclass(model_class, Deepseek3_2TpPartModel): + mem_class = Deepseek3_2MemoryManager + logger.info(f"Model kv cache using default, mem_manager class: {mem_class}") + return mem_class + from lightllm.models import Deepseek2TpPartModel if issubclass(model_class, Deepseek2TpPartModel): diff --git a/lightllm/models/deepseek3_2/infer_struct.py b/lightllm/models/deepseek3_2/_del_infer_struct.py similarity index 100% rename from lightllm/models/deepseek3_2/infer_struct.py rename to lightllm/models/deepseek3_2/_del_infer_struct.py diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 77804096b1..09b6be4dbf 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -1,25 +1,10 @@ import copy -import json -import logging -import os - from lightllm.models.registry import ModelRegistry from lightllm.models.deepseek2.model import 
Deepseek2TpPartModel -from lightllm.utils.envs_utils import get_env_start_args - -_logger = logging.getLogger(__name__) - -# When ENABLE_NSA is set, use the full V32 NSA (Native Sparse Attention) pipeline -# including the indexer, custom memory manager, and NSA-aware attention kernels. -# When not set, fall back to the DeepSeek V3 (Deepseek2) inference path while -# keeping V32-specific tokenizer/parser support intact. -_ENABLE_NSA = os.environ.get("ENABLE_NSA", "0").lower() in ("1", "true") - -if _ENABLE_NSA: - from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight - from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer - from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo - from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager, Deepseek3_2FP8KVMemoryManager +from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight +from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer +from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo +from lightllm.common.basemodel.attention import get_nsa_prefill_att_backend_class, get_nsa_decode_att_backend_class class DeepSeekV32Tokenizer: @@ -114,49 +99,17 @@ def apply_chat_template( @ModelRegistry(["deepseek_v32"]) class Deepseek3_2TpPartModel(Deepseek2TpPartModel): - # When ENABLE_NSA is set, override with V32-specific NSA classes. - # Otherwise, inherit the V3/V2 classes from Deepseek2TpPartModel. 
- if _ENABLE_NSA: - transformer_weight_class = Deepseek3_2TransformerLayerWeight - transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer - infer_state_class = Deepseek3_2InferStateInfo - - def __init__(self, kvargs): - super().__init__(kvargs) - if _ENABLE_NSA: - self.index_topk = self.config["index_topk"] - else: - _logger.info("ENABLE_NSA is not set, using DeepSeek V3 inference path (no NSA indexer).") - return - def _init_inferstate_cls(self): - if _ENABLE_NSA: - self.infer_state_class = Deepseek3_2InferStateInfo - else: - super()._init_inferstate_cls() - - def _init_mem_manager(self): - if not _ENABLE_NSA: - # Fall back to the standard V3/V2 memory manager (no indexer buffer). - return super()._init_mem_manager() - - manager_class = Deepseek3_2MemoryManager - if get_env_start_args().llm_kv_type == "fp8kv": - manager_class = Deepseek3_2FP8KVMemoryManager - - # mtp 模式下需要在mem manger上扩展draft model使用的layer - added_mtp_layer_num = 0 - if get_env_start_args().mtp_mode == "deepseekv3_eagle": - added_mtp_layer_num += 1 - elif get_env_start_args().mtp_mode == "deepseekv3_vanilla": - added_mtp_layer_num += get_env_start_args().mtp_step - - self.mem_manager = manager_class( - self.max_total_token_num, - dtype=self.data_type, - head_num=1, - head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], - layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, - mem_fraction=self.mem_fraction, - ) + # weight class + transformer_weight_class = Deepseek3_2TransformerLayerWeight + + # infer class + transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer + + # infer state class + infer_state_class = Deepseek3_2InferStateInfo + + def _init_att_backend(self): + self.prefill_att_backend = get_nsa_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend = get_nsa_decode_att_backend_class(index=0)(model=self) return diff --git a/lightllm/utils/backend_validator.py b/lightllm/utils/backend_validator.py index 
0e2f9c962e..6c5fe90309 100644 --- a/lightllm/utils/backend_validator.py +++ b/lightllm/utils/backend_validator.py @@ -187,6 +187,47 @@ def _validate_sdpa(): return True, None +def _validate_flashmla_sparse(): + """Validate flashmla_sparse (NSA) with ground truth.""" + try: + from sgl_kernel.flash_mla import flash_mla_sparse_fwd + + # need sgl-kernel version >= 0.3.21 and torch >= 2.9.1 + except Exception as e: + return False, f"sgl_kernel.flash_mla import failed: {type(e).__name__}: {e}" + + batch, heads, seq, dim = 1, 64, 128, 512 + 64 + dtype = torch.bfloat16 + device = "cuda" + + q = torch.randn(batch * seq, heads, dim, dtype=dtype, device=device) + kv = torch.zeros(batch * seq, 1, dim, dtype=dtype, device=device) + + index_topk = 128 + topk_indices = torch.zeros(batch * seq, index_topk, dtype=torch.int32, device=device) + for i in range(seq): + topk_indices[i, :] = torch.arange(index_topk, dtype=torch.int32, device=device) + + topk_indices = topk_indices.view(batch * seq, 1, index_topk) + + softmax_scale = 1.0 / (dim ** 0.5) + kv_lora_rank = dim + + try: + mla_out, _, _ = flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=topk_indices, + sm_scale=softmax_scale, + d_v=kv_lora_rank, + ) + torch.cuda.synchronize() + except Exception as e: + return False, f"flash_mla_sparse_fwd run failed: {type(e).__name__}: {e}" + + return True, None + + def _run_in_subprocess(backend_name, pipe): """Run validation in subprocess with suppressed output.""" import sys @@ -207,6 +248,8 @@ def _run_in_subprocess(backend_name, pipe): success, err = _validate_flashinfer() elif backend_name == "triton": success, err = _validate_triton() + elif backend_name == "flashmla_sparse": + success, err = _validate_flashmla_sparse() else: success, err = False, f"Unknown backend: {backend_name}" pipe.send((success, err)) From 579a47e0251bb9c62e29d0497948276990ab2014 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 3 Mar 2026 08:41:54 +0000 Subject: [PATCH 25/58] fix --- 
.../layer_weights/nsa_indexer_layer_weight.py | 55 ------------------- .../layer_weights/transformer_layer_weight.py | 54 ++++++++++++++++-- 2 files changed, 50 insertions(+), 59 deletions(-) delete mode 100644 lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py diff --git a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py deleted file mode 100644 index 023b89979b..0000000000 --- a/lightllm/models/deepseek3_2/layer_weights/nsa_indexer_layer_weight.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing_extensions import override - -import torch - -from lightllm.common.basemodel.layer_weights.transformer_layer_weight import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, LayerNormWeight - - -class NSAIndexerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, quant_cfg): - super().__init__(layer_num, data_type, network_config, quant_cfg) - return - - def _parse_config(self): - self.q_lora_rank = self.network_config_["q_lora_rank"] - self.index_n_heads = self.network_config_["index_n_heads"] - self.index_head_dim = self.network_config_["index_head_dim"] - self.hidden_size = self.network_config_["hidden_size"] - - def _init_weight(self): - prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" - - self.wq_b_proj_ = ROWMMWeight( - in_dim=self.q_lora_rank, - out_dims=[self.index_n_heads * self.index_head_dim], - weight_names=f"{prefix}.wq_b.weight", - data_type=self.data_type_, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) - self.wk_proj_ = ROWMMWeight( - in_dim=self.hidden_size, - out_dims=[self.index_head_dim], - weight_names=f"{prefix}.wk.weight", - data_type=self.data_type_, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) - self.k_norm_ = LayerNormWeight( - dim=self.index_head_dim, - weight_name=f"{prefix}.k_norm.weight", - data_type=self.data_type_, 
- bias_name=f"{prefix}.k_norm.bias", - ) - self.weights_proj_ = ROWMMWeight( - in_dim=self.hidden_size, - out_dims=[self.index_n_heads], - weight_names=f"{prefix}.weights_proj.weight", - data_type=self.data_type_, - quant_method=None, - tp_rank=0, - tp_world_size=1, - ) diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py index adcba51cc9..c8d285db4b 100644 --- a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py @@ -1,12 +1,58 @@ from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, LayerNormWeight class Deepseek3_2TransformerLayerWeight(Deepseek2TransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): - self.index_topk = network_config["index_topk"] super().__init__(layer_num, data_type, network_config, quant_cfg) - self.indexer_layer_weight = NSAIndexerWeight( - layer_num=layer_num, data_type=data_type, network_config=network_config, quant_cfg=quant_cfg - ) return + + def _parse_config(self): + super()._parse_config() + self.q_lora_rank = self.network_config_["q_lora_rank"] + self.index_n_heads = self.network_config_["index_n_heads"] + self.index_head_dim = self.network_config_["index_head_dim"] + self.hidden_size = self.network_config_["hidden_size"] + + def _init_weight(self): + super()._init_weight() + self._init_indexer_weight() + + def _init_indexer_weight(self): + + prefix = f"model.layers.{self.layer_num_}.self_attn.indexer" + + self.wq_b_proj_ = ROWMMWeight( + in_dim=self.q_lora_rank, + out_dims=[self.index_n_heads * self.index_head_dim], + weight_names=f"{prefix}.wq_b.weight", + 
data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.wk_proj_ = ROWMMWeight( + in_dim=self.hidden_size, + out_dims=[self.index_head_dim], + weight_names=f"{prefix}.wk.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + self.k_norm_ = LayerNormWeight( + dim=self.index_head_dim, + weight_name=f"{prefix}.k_norm.weight", + data_type=self.data_type_, + bias_name=f"{prefix}.k_norm.bias", + ) + self.weights_proj_ = ROWMMWeight( + in_dim=self.hidden_size, + out_dims=[self.index_n_heads], + weight_names=f"{prefix}.weights_proj.weight", + data_type=self.data_type_, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) From 8b6b5b7fcd24ee3b569758f517ed7733bed908f7 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 3 Mar 2026 08:43:00 +0000 Subject: [PATCH 26/58] fix --- .../models/deepseek3_2/layer_weights/transformer_layer_weight.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py index c8d285db4b..eb14c82b49 100644 --- a/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek3_2/layer_weights/transformer_layer_weight.py @@ -1,5 +1,4 @@ from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight -from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, LayerNormWeight From c0157d3f97c6693a010cd443bc31961b12fb711c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 3 Mar 2026 10:14:14 +0000 Subject: [PATCH 27/58] fix --- .../layer_infer/transformer_layer_infer.py | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py 
b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 13a0c1394f..9a772c17ff 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -3,12 +3,11 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo +from lightllm.models.deepseek3_2._del_infer_struct import Deepseek3_2InferStateInfo from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.basemodel.attention.base_att import AttControl -from lightllm.common.basemodel.attention.create_utils import get_nsa_prefill_att_backend_class class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): @@ -18,19 +17,8 @@ def __init__(self, layer_num, network_config): self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_) self.topk_indices = None - - # Initialize NSA attention backend (singleton, lazy initialization) - self._nsa_backend_class = get_nsa_prefill_att_backend_class() - self._nsa_backend = None return - def _get_nsa_backend(self): - """Get or create the NSA backend (lazy initialization).""" - if self._nsa_backend is None: - # NSA backend doesn't require model reference for basic operations - self._nsa_backend = self._nsa_backend_class(model=None) - return self._nsa_backend - def _get_qkv( self, input: torch.Tensor, @@ -88,11 +76,7 @@ def _context_attention_kernel( }, ) - # Create 
prefill state and execute attention - nsa_backend = self._get_nsa_backend() - prefill_state = nsa_backend.create_att_prefill_state(infer_state) - prefill_state.init_state() - mla_out = prefill_state.prefill_att( + mla_out = infer_state.prefill_att_state.prefill_att( q=q_all, k=infer_state.mem_manager.kv_buffer[self.layer_num_], v=None, From f92716aad822373955fa2bf2047fa2bb90d7e1f4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 3 Mar 2026 10:16:46 +0000 Subject: [PATCH 28/58] fix --- .../models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 7a9aeb46c9..06c56cd275 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -3,7 +3,7 @@ import torch from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo +from lightllm.models.deepseek3_2._del_infer_struct import Deepseek3_2InferStateInfo from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks From 879e907f77936ae95e9dfc79a9a3305f0817d2bf Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 3 Mar 2026 20:52:38 +0800 Subject: [PATCH 29/58] fix --- .../layer_infer/nsa_indexer_layer_inder.py | 8 +- .../triton_kernel/extract_indexer_ks.py | 108 +++++++++++------- 2 files changed, 73 insertions(+), 43 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py 
b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 06c56cd275..8a16ab0dcf 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -92,7 +92,13 @@ def get_indices( # Use efficient Triton kernel to extract FP8 keys and scales from buffer k_fp8_, k_scale_ = extract_indexer_ks( - infer_state.indexer_ks_buffer.kv_buffer[self.layer_idx_], infer_state.req_all_mem_index + I_buffer=infer_state.mem_manager.indexer_ks_buffer.kv_buffer[self.layer_idx_], + b_seq_len=infer_state.b_seq_len, + b_req_idx=infer_state.b_req_idx, + b_cu_kv_seq_len=infer_state.b1_cu_kv_seq_len, + req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, + out_token_num=infer_state.b_seq_len.shape[0] * infer_state.max_kv_seq_len, + max_kv_seq_len=infer_state.max_kv_seq_len, ) # Get actual sequence length from q (which comes from q_lora) diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index 48bc34ad6e..5684eb4114 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -6,74 +6,98 @@ @triton.jit def _fwd_kernel_extract_indexer_ks( - I_buffer, # Input buffer [large_size, 1, 132] uint8 - SrcLoc, # Source indices [req_size] int32/int64 - O_fp8, # Output FP8 [req_size, 128] float8_e4m3fn - O_scale, # Output scale [req_size] float32 - stride_i_bs, - stride_i_h, - stride_i_d, + in_fp8, + stride_in_fp8_bs, + stride_in_fp8_h, + stride_in_fp8_d, + in_fp8_scale, + stride_in_scale_bs, + stride_in_scale_h, + stride_in_scale_d, + req_to_token_indexs, + stride_req_to_token_m, + stride_req_to_token_n, + b_seq_len, + b_req_idx, + b_cu_kv_seq_len, + O_fp8, stride_o_fp8_bs, stride_o_fp8_d, + O_scale, stride_o_scale_bs, + stride_o_scale_d, BLOCK_DMODEL: tl.constexpr, ): - cur_index = tl.program_id(0) + 
cur_req_index = tl.program_id(0) + token_start_index = tl.program_id(1) + cur_req_idx = tl.load(b_req_idx + cur_req_index) + cur_seq_len = tl.load(b_seq_len + cur_req_index) offs_d = tl.arange(0, BLOCK_DMODEL) + store_start_index = tl.load(b_cu_kv_seq_len + cur_req_index) - src_index = tl.load(SrcLoc + cur_index).to(tl.int64) + for i in range(token_start_index, cur_seq_len, tl.num_programs(1)): + mem_index = tl.load(req_to_token_indexs + cur_req_idx * stride_req_to_token_m + i * stride_req_to_token_n) - i_k_ptrs = I_buffer + src_index * stride_i_bs + stride_i_d * offs_d - k_fp8_as_uint8 = tl.load(i_k_ptrs) + in_fp8_ptrs = in_fp8 + mem_index * stride_in_fp8_bs + 0 * stride_in_fp8_h + stride_in_fp8_d * offs_d + kv_fp8 = tl.load(in_fp8_ptrs) - k_fp8 = k_fp8_as_uint8.to(tl.float8e4nv, bitcast=True) + in_scale_ptrs = in_fp8_scale + mem_index * stride_in_scale_bs + 0 * stride_in_scale_h + 0 * stride_in_scale_d + kv_scale = tl.load(in_scale_ptrs) - o_k_ptrs = O_fp8 + cur_index * stride_o_fp8_bs + stride_o_fp8_d * offs_d - tl.store(o_k_ptrs, k_fp8) + o_fp8_ptrs = O_fp8 + (store_start_index + i) * stride_o_fp8_bs + stride_o_fp8_d * offs_d + tl.store(o_fp8_ptrs, kv_fp8) - i_scale_base_ptr = I_buffer + src_index * stride_i_bs + BLOCK_DMODEL * stride_i_d - - byte0 = tl.load(i_scale_base_ptr + 0 * stride_i_d).to(tl.uint32) - byte1 = tl.load(i_scale_base_ptr + 1 * stride_i_d).to(tl.uint32) - byte2 = tl.load(i_scale_base_ptr + 2 * stride_i_d).to(tl.uint32) - byte3 = tl.load(i_scale_base_ptr + 3 * stride_i_d).to(tl.uint32) - - scale_as_uint32 = byte0 | (byte1 << 8) | (byte2 << 16) | (byte3 << 24) - - k_scale = scale_as_uint32.to(tl.float32, bitcast=True) - - o_scale_ptr = O_scale + cur_index * stride_o_scale_bs - tl.store(o_scale_ptr, k_scale) + o_scale_ptr = O_scale + (store_start_index + i) * stride_o_scale_bs + tl.store(o_scale_ptr, kv_scale) return @torch.no_grad() -def extract_indexer_ks(I_buffer: torch.Tensor, SrcLoc: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - 
req_size = SrcLoc.shape[0] +def extract_indexer_ks( + I_buffer: torch.Tensor, + b_seq_len: torch.Tensor, + b_req_idx: torch.Tensor, + b_cu_kv_seq_len: torch.Tensor, + req_to_token_indexs: torch.Tensor, + out_token_num: int, + max_kv_seq_len: int, +) -> tuple[torch.Tensor, torch.Tensor]: head_dim = 128 assert I_buffer.dtype == torch.uint8, f"Expected I_buffer dtype=uint8, got {I_buffer.dtype}" assert I_buffer.shape[2] == 132, f"Expected I_buffer last dim=132, got {I_buffer.shape[2]}" + in_fp8 = I_buffer[:, :, 0:128].view(dtype=torch.float8_e4m3fn) + in_fp8_scale = I_buffer[:, :, 128:132].view(dtype=torch.float32) # Allocate output tensors - O_fp8 = torch.empty((req_size, head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) - O_scale = torch.empty((req_size,), dtype=torch.float32, device=I_buffer.device) + O_fp8 = torch.empty((out_token_num, head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) + O_scale = torch.empty((out_token_num,), dtype=torch.float32, device=I_buffer.device) - grid = (req_size,) + grid = (b_seq_len.shape[0], min(256, max_kv_seq_len)) num_warps = 1 _fwd_kernel_extract_indexer_ks[grid]( - I_buffer, - SrcLoc, - O_fp8, - O_scale, - I_buffer.stride(0), - I_buffer.stride(1), - I_buffer.stride(2), - O_fp8.stride(0), - O_fp8.stride(1), - O_scale.stride(0), + in_fp8, + stride_in_fp8_bs=in_fp8.stride(0), + stride_in_fp8_h=in_fp8.stride(1), + stride_in_fp8_d=in_fp8.stride(2), + in_fp8_scale=in_fp8_scale, + stride_in_scale_bs=in_fp8_scale.stride(0), + stride_in_scale_h=in_fp8_scale.stride(1), + stride_in_scale_d=in_fp8_scale.stride(2), + req_to_token_indexs=req_to_token_indexs, + stride_req_to_token_m=req_to_token_indexs.stride(0), + stride_req_to_token_n=req_to_token_indexs.stride(1), + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_cu_kv_seq_len=b_cu_kv_seq_len, + O_fp8=O_fp8, + stride_o_fp8_bs=O_fp8.stride(0), + stride_o_fp8_d=O_fp8.stride(1), + O_scale=O_scale, + stride_o_scale_bs=O_scale.stride(0), + 
stride_o_scale_d=O_scale.stride(1), BLOCK_DMODEL=head_dim, num_warps=num_warps, num_stages=1, From 68efbdcdc82a2b7d76b985cdddffd1ddc466ac20 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 3 Mar 2026 21:34:51 +0800 Subject: [PATCH 30/58] fix --- .../models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 8a16ab0dcf..9bbb5a68bc 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -1,4 +1,3 @@ -from sgl_kernel import fast_topk_transform_fused import deep_gemm import torch from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer @@ -115,6 +114,8 @@ def get_indices( logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + from sgl_kernel import fast_topk_transform_fused + return fast_topk_transform_fused( score=logits, lengths=lengths, From 7d8e54d2561d01e1583e8b308257887aaf32834f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 3 Mar 2026 22:26:11 +0800 Subject: [PATCH 31/58] fix --- .../triton_kernel/extract_indexer_ks.py | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index 5684eb4114..55b967e066 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -19,21 +19,28 @@ def _fwd_kernel_extract_indexer_ks( stride_req_to_token_n, b_seq_len, b_req_idx, - b_cu_kv_seq_len, O_fp8, stride_o_fp8_bs, stride_o_fp8_d, O_scale, stride_o_scale_bs, stride_o_scale_d, + mtp_step, BLOCK_DMODEL: tl.constexpr, + BLOCK_SEQ_LEN: tl.constexpr, ): - 
cur_req_index = tl.program_id(0) + origin_cur_req_index = tl.program_id(0) + cur_req_index = (origin_cur_req_index + 1) * (mtp_step + 1) - 1 token_start_index = tl.program_id(1) cur_req_idx = tl.load(b_req_idx + cur_req_index) cur_seq_len = tl.load(b_seq_len + cur_req_index) offs_d = tl.arange(0, BLOCK_DMODEL) - store_start_index = tl.load(b_cu_kv_seq_len + cur_req_index) + b_seq_len = tl.load( + b_seq_len + (tl.arange(0, BLOCK_SEQ_LEN) + 1) * (mtp_step + 1) - 1, + mask=tl.arange(0, BLOCK_SEQ_LEN) < origin_cur_req_index, + other=0, + ) + store_start_index = tl.sum(b_seq_len) for i in range(token_start_index, cur_seq_len, tl.num_programs(1)): mem_index = tl.load(req_to_token_indexs + cur_req_idx * stride_req_to_token_m + i * stride_req_to_token_n) @@ -58,10 +65,10 @@ def extract_indexer_ks( I_buffer: torch.Tensor, b_seq_len: torch.Tensor, b_req_idx: torch.Tensor, - b_cu_kv_seq_len: torch.Tensor, req_to_token_indexs: torch.Tensor, out_token_num: int, max_kv_seq_len: int, + mtp_step: int, ) -> tuple[torch.Tensor, torch.Tensor]: head_dim = 128 @@ -71,10 +78,11 @@ def extract_indexer_ks( in_fp8_scale = I_buffer[:, :, 128:132].view(dtype=torch.float32) # Allocate output tensors - O_fp8 = torch.empty((out_token_num, head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) - O_scale = torch.empty((out_token_num,), dtype=torch.float32, device=I_buffer.device) + O_fp8 = torch.empty((out_token_num // (mtp_step + 1), head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) + O_scale = torch.empty((out_token_num // (mtp_step + 1),), dtype=torch.float32, device=I_buffer.device) - grid = (b_seq_len.shape[0], min(256, max_kv_seq_len)) + assert b_seq_len.shape[0] % (mtp_step + 1) == 0 + grid = (b_seq_len.shape[0] // (mtp_step + 1), min(256, max_kv_seq_len)) num_warps = 1 _fwd_kernel_extract_indexer_ks[grid]( @@ -91,14 +99,15 @@ def extract_indexer_ks( stride_req_to_token_n=req_to_token_indexs.stride(1), b_seq_len=b_seq_len, b_req_idx=b_req_idx, - 
b_cu_kv_seq_len=b_cu_kv_seq_len, O_fp8=O_fp8, stride_o_fp8_bs=O_fp8.stride(0), stride_o_fp8_d=O_fp8.stride(1), O_scale=O_scale, stride_o_scale_bs=O_scale.stride(0), stride_o_scale_d=O_scale.stride(1), + mtp_step=mtp_step, BLOCK_DMODEL=head_dim, + BLOCK_SEQ_LEN=triton.next_power_of_2(b_seq_len.shape[0] // (mtp_step + 1)), num_warps=num_warps, num_stages=1, ) From 932b402f84ceaf1dcb8302a33538a6f8e393d1e5 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 3 Mar 2026 22:32:55 +0800 Subject: [PATCH 32/58] Fix --- .../models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py index 9bbb5a68bc..c3561562d6 100644 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py @@ -94,10 +94,10 @@ def get_indices( I_buffer=infer_state.mem_manager.indexer_ks_buffer.kv_buffer[self.layer_idx_], b_seq_len=infer_state.b_seq_len, b_req_idx=infer_state.b_req_idx, - b_cu_kv_seq_len=infer_state.b1_cu_kv_seq_len, req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, out_token_num=infer_state.b_seq_len.shape[0] * infer_state.max_kv_seq_len, max_kv_seq_len=infer_state.max_kv_seq_len, + mtp_step=0, ) # Get actual sequence length from q (which comes from q_lora) From f88503eaf12abffd1e29e6f76a15674ce6351612 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 01:24:36 +0000 Subject: [PATCH 33/58] fix --- .../destindex_copy_indexer_ks.py | 68 +++++++++---------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py index a345bd1e20..a3115cc61c 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py +++ 
b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_indexer_ks.py @@ -9,14 +9,16 @@ def _fwd_kernel_destindex_copy_indexer_ks( K_fp8, K_scale, DestLoc, - O_buffer, stride_k_bs, stride_k_d, stride_scale_bs, stride_scale_d, + O_fp8, stride_o_bs, - stride_o_h, stride_o_d, + O_fp8_scale, + stride_o_scale_bs, + stride_o_scale_d, BLOCK_DMODEL: tl.constexpr, ): cur_index = tl.program_id(0) @@ -29,24 +31,13 @@ def _fwd_kernel_destindex_copy_indexer_ks( k_fp8_ptrs = K_fp8 + cur_index * stride_k_bs + stride_k_d * offs_d k_fp8 = tl.load(k_fp8_ptrs) - k_scale = tl.load(K_scale + cur_index * stride_scale_bs) + k_scale = tl.load(K_scale + cur_index * stride_scale_bs + stride_scale_d * 0) - # Store K_fp8 to O_buffer[:, 0, :128] - # Convert fp8 to uint8 through bitcast for storage in uint8 buffer - o_k_ptrs = O_buffer + dest_index * stride_o_bs + stride_o_d * offs_d - k_fp8_as_uint8 = k_fp8.to(tl.uint8, bitcast=True) - tl.store(o_k_ptrs, k_fp8_as_uint8) - - # Store K_scale to O_buffer[:, 0, 128:132] (4 bytes for float32) - # Convert float32 scale to 4 uint8 bytes using bitcast and bit manipulation - o_scale_ptr = O_buffer + dest_index * stride_o_bs + BLOCK_DMODEL * stride_o_d - scale_as_uint32 = k_scale.to(tl.float32, bitcast=True).to(tl.uint32, bitcast=True) - - # Store each byte of the float32 scale (little-endian) - for i in range(4): - byte_val = ((scale_as_uint32 >> (i * 8)) & 0xFF).to(tl.uint8) - tl.store(o_scale_ptr + i * stride_o_d, byte_val) + o_k_ptrs = O_fp8 + dest_index * stride_o_bs + stride_o_d * offs_d + tl.store(o_k_ptrs, k_fp8) + o_scale_ptr = O_fp8_scale + dest_index * stride_o_scale_bs + stride_o_scale_d * 0 + tl.store(o_scale_ptr, k_scale) return @@ -58,31 +49,34 @@ def destindex_copy_indexer_ks( head_dim = K_fp8.shape[1] assert head_dim == 128, f"Expected head_dim=128, got {head_dim}" + assert O_buffer.shape[2] == 132, f"Expected O_buffer last dim=132, got {O_buffer.shape[2]}" + assert K_fp8.shape[0] == seq_len, f"Expected K_fp8 shape[0]={seq_len}, got 
{K_fp8.shape[0]}" + K_fp8 = K_fp8.view(-1, head_dim) + K_scale = K_scale.view(-1, 1) - # Handle cases where tensor lengths don't match (e.g., during prefix cache) - actual_seq_len = min(K_scale.shape[0], seq_len) - if actual_seq_len < seq_len: - K_fp8 = K_fp8[:actual_seq_len] - K_scale = K_scale[:actual_seq_len] - DestLoc = DestLoc[:actual_seq_len] + assert K_fp8.shape[0] == seq_len, f"Expected K_fp8 shape[0]={seq_len}, got {K_fp8.shape[0]}" + assert K_scale.shape[0] == seq_len, f"Expected K_scale shape[0]={seq_len}, got {K_scale.shape[0]}" - assert O_buffer.shape[2] == 132, f"Expected O_buffer last dim=132, got {O_buffer.shape[2]}" + O_fp8 = O_buffer[:, :, :128].view(dtype=torch.float8_e4m3fn).view(-1, head_dim) + O_fp8_scale = O_buffer[:, :, 128:132].view(dtype=torch.float32).view(-1, 1) - grid = (actual_seq_len,) + grid = (seq_len,) num_warps = 1 _fwd_kernel_destindex_copy_indexer_ks[grid]( - K_fp8, - K_scale, - DestLoc, - O_buffer, - K_fp8.stride(0), - K_fp8.stride(1), - K_scale.stride(0), - K_scale.stride(1), - O_buffer.stride(0), - O_buffer.stride(1), - O_buffer.stride(2), + K_fp8=K_fp8, + K_scale=K_scale, + DestLoc=DestLoc, + stride_k_bs=K_fp8.stride(0), + stride_k_d=K_fp8.stride(1), + stride_scale_bs=K_scale.stride(0), + stride_scale_d=K_scale.stride(1), + O_fp8=O_fp8, + stride_o_bs=O_fp8.stride(0), + stride_o_d=O_fp8.stride(1), + O_fp8_scale=O_fp8_scale, + stride_o_scale_bs=O_fp8_scale.stride(0), + stride_o_scale_d=O_fp8_scale.stride(1), BLOCK_DMODEL=head_dim, num_warps=num_warps, num_stages=1, From 54f8d5cae08bfcc93746fee7a512c4dec26a1915 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 01:36:45 +0000 Subject: [PATCH 34/58] fix --- .../layer_infer/nsa_indexer_layer_inder.py | 159 ------------------ .../layer_infer/transformer_layer_infer.py | 117 ++++++++++++- 2 files changed, 114 insertions(+), 162 deletions(-) delete mode 100644 lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py diff --git 
a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py b/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py deleted file mode 100644 index c3561562d6..0000000000 --- a/lightllm/models/deepseek3_2/layer_infer/nsa_indexer_layer_inder.py +++ /dev/null @@ -1,159 +0,0 @@ -import deep_gemm -import torch -from lightllm.common.basemodel.layer_infer.base_layer_infer import BaseLayerInfer -from lightllm.models.deepseek3_2.layer_weights.nsa_indexer_layer_weight import NSAIndexerWeight -from lightllm.models.deepseek3_2._del_infer_struct import Deepseek3_2InferStateInfo -from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant -from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks -from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class NSAIndexerInfer(BaseLayerInfer): - def __init__(self, layer_idx, network_config): - super().__init__() - self.layer_idx_ = layer_idx - self.network_config_ = network_config - self.index_topk = network_config["index_topk"] - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ - self.tp_k_head_num_ = 1 - self.tp_v_head_num_ = 1 - self.qk_nope_head_dim = network_config["qk_nope_head_dim"] - self.qk_rope_head_dim = network_config["qk_rope_head_dim"] - self.index_head_dim = network_config["index_head_dim"] - self.eps = network_config["rms_norm_eps"] - self.block_size = network_config["quantization_config"]["weight_block_size"][0] - self.scale_fmt = network_config["quantization_config"]["scale_fmt"] - self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) - self.index_n_heads = network_config["index_n_heads"] - self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale - - 
return - - def ref_fp8_mqa_logits( - self, - q: torch.Tensor, - kv: torch.Tensor, - weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, - cu_seqlen_ke: torch.Tensor, - cost_only: bool = False, - ): - seq_len_kv = kv.shape[0] - - if cost_only: - start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) - end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) - count_ones_per_row = (end - start).clamp(min=0) - return count_ones_per_row.sum() - - k = kv - q = q.float() - k = k.float() - - mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] - mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] - mask = mask_lo & mask_hi - - score = torch.einsum("mhd,nd->hmn", q, k) - logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float("-inf")) - - cost = mask.sum() - return logits, cost - - def get_indices( - self, - hidden_states: torch.Tensor, - q_lora: torch.Tensor, - infer_state: Deepseek3_2InferStateInfo, - layer_weight: NSAIndexerWeight, - ) -> torch.Tensor: - - q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) - q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) - k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) - - destindex_copy_indexer_ks( - k_fp8, k_scale, infer_state.mem_index, infer_state.indexer_ks_buffer.kv_buffer[self.layer_idx_] - ) - - weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale - weights = weights.unsqueeze(-1) * q_scale - - ks = infer_state.ks - ke = infer_state.ke - lengths = infer_state.lengths - page_table_1 = infer_state.page_table_size_1 - - # Use efficient Triton kernel to extract FP8 keys and scales from buffer - k_fp8_, k_scale_ = extract_indexer_ks( - I_buffer=infer_state.mem_manager.indexer_ks_buffer.kv_buffer[self.layer_idx_], - b_seq_len=infer_state.b_seq_len, - b_req_idx=infer_state.b_req_idx, - 
req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, - out_token_num=infer_state.b_seq_len.shape[0] * infer_state.max_kv_seq_len, - max_kv_seq_len=infer_state.max_kv_seq_len, - mtp_step=0, - ) - - # Get actual sequence length from q (which comes from q_lora) - # This may differ from ks.shape[0] during certain operations - actual_seq_len = q.shape[0] - - # ks, ke, lengths, and weights should all match actual_seq_len - # Slice them if they don't match - if ks.shape[0] != actual_seq_len: - ks = ks[:actual_seq_len] - ke = ke[:actual_seq_len] - lengths = lengths[:actual_seq_len] - weights = weights[:actual_seq_len] - - logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) - - from sgl_kernel import fast_topk_transform_fused - - return fast_topk_transform_fused( - score=logits, - lengths=lengths, - page_table_size_1=page_table_1, - cu_seqlens_q=infer_state.b1_cu_q_seq_len, - topk=self.index_topk, - ) - - @staticmethod - def _rotate_activation(x: torch.Tensor) -> torch.Tensor: - assert x.dtype == torch.bfloat16 - from sgl_kernel import hadamard_transform - - hidden_size = x.size(-1) - assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform." 
- return hadamard_transform(x, scale=hidden_size ** -0.5) - - def _get_q_k_bf16( - self, - hidden_states: torch.Tensor, - q_lora: torch.Tensor, - infer_state: Deepseek3_2InferStateInfo, - layer_weight: NSAIndexerWeight, - ): - q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) - k = layer_weight.wk_proj_.mm(hidden_states) - - k = layer_weight.k_norm_(k, eps=self.eps) - - # Slice position_cos and position_sin to match actual token length - actual_seq_len = q.shape[0] - rotary_emb_fwd( - q[:, :, : self.qk_rope_head_dim], - k[:, None, : self.qk_rope_head_dim], - infer_state.position_cos[:actual_seq_len], - infer_state.position_sin[:actual_seq_len], - ) - - q = self._rotate_activation(q) - k = self._rotate_activation(k) - return q, k diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 9a772c17ff..af325c0de4 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,13 +1,14 @@ import torch from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer -from lightllm.models.deepseek3_2.layer_infer.nsa_indexer_layer_inder import NSAIndexerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.models.deepseek3_2._del_infer_struct import Deepseek3_2InferStateInfo -from lightllm.models.deepseek3_2.triton_kernel.token_group_quant import per_token_group_quant_mla_deep_gemm_masked_fp8 from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant +from 
lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks +from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): @@ -15,7 +16,7 @@ def __init__(self, layer_num, network_config): self.index_topk = network_config["index_topk"] super().__init__(layer_num, network_config) - self.indexer = NSAIndexerInfer(layer_idx=self.layer_num_, network_config=self.network_config_) + self.indexer = NsaInfer(layer_idx=self.layer_num_, network_config=self.network_config_) self.topk_indices = None return @@ -119,3 +120,113 @@ def _token_attention_kernel( att_control=att_control, ) return o_tensor + + +class NsaInfer: + def __init__(self, layer_idx: int, network_config: dict): + super().__init__() + self.layer_idx_ = layer_idx + self.network_config_ = network_config + self.index_topk = network_config["index_topk"] + self.qk_nope_head_dim = network_config["qk_nope_head_dim"] + self.qk_rope_head_dim = network_config["qk_rope_head_dim"] + self.index_head_dim = network_config["index_head_dim"] + self.eps = network_config["rms_norm_eps"] + self.block_size = network_config["quantization_config"]["weight_block_size"][0] + self.scale_fmt = network_config["quantization_config"]["scale_fmt"] + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.index_n_heads = network_config["index_n_heads"] + self.index_n_heads_scale = (self.index_n_heads ** -0.5) * self.softmax_scale + + def get_indices( + self, + hidden_states: torch.Tensor, + q_lora: torch.Tensor, + infer_state: Deepseek3_2InferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + ) -> torch.Tensor: + + q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight) + q_fp8, q_scale = act_quant(q, self.block_size, self.scale_fmt) + k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) + + destindex_copy_indexer_ks( + k_fp8, 
k_scale, infer_state.mem_index, infer_state.mem_manager.indexer_ks_buffer.kv_buffer[self.layer_idx_] + ) + + weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale + weights = weights.unsqueeze(-1) * q_scale + + ks = infer_state.ks + ke = infer_state.ke + lengths = infer_state.lengths + page_table_1 = infer_state.page_table_size_1 + + # Use efficient Triton kernel to extract FP8 keys and scales from buffer + k_fp8_, k_scale_ = extract_indexer_ks( + I_buffer=infer_state.mem_manager.indexer_ks_buffer.kv_buffer[self.layer_idx_], + b_seq_len=infer_state.b_seq_len, + b_req_idx=infer_state.b_req_idx, + req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, + out_token_num=infer_state.b_seq_len.shape[0] * infer_state.max_kv_seq_len, + max_kv_seq_len=infer_state.max_kv_seq_len, + mtp_step=0, + ) + + # Get actual sequence length from q (which comes from q_lora) + # This may differ from ks.shape[0] during certain operations + actual_seq_len = q.shape[0] + + # ks, ke, lengths, and weights should all match actual_seq_len + # Slice them if they don't match + if ks.shape[0] != actual_seq_len: + ks = ks[:actual_seq_len] + ke = ke[:actual_seq_len] + lengths = lengths[:actual_seq_len] + weights = weights[:actual_seq_len] + + import deep_gemm + + logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) + + from sgl_kernel import fast_topk_transform_fused + + return fast_topk_transform_fused( + score=logits, + lengths=lengths, + page_table_size_1=page_table_1, + cu_seqlens_q=infer_state.b1_cu_q_seq_len, + topk=self.index_topk, + ) + + @staticmethod + def _rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from sgl_kernel import hadamard_transform + + hidden_size = x.size(-1) + assert (hidden_size & (hidden_size - 1)) == 0, "Hidden size must be a power of 2 for Hadamard transform." 
+ return hadamard_transform(x, scale=hidden_size ** -0.5) + + def _get_q_k_bf16( + self, + hidden_states: torch.Tensor, + q_lora: torch.Tensor, + infer_state: Deepseek3_2InferStateInfo, + layer_weight: Deepseek3_2TransformerLayerWeight, + ): + q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) + k = layer_weight.wk_proj_.mm(hidden_states) + + k = layer_weight.k_norm_(k, eps=self.eps) + + rotary_emb_fwd( + q[:, :, : self.qk_rope_head_dim], + k[:, None, : self.qk_rope_head_dim], + infer_state.position_cos, + infer_state.position_sin, + ) + + q = self._rotate_activation(q) + k = self._rotate_activation(k) + return q, k From bc618f2564a8da5171713eb453502a5f9deded64 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 01:39:19 +0000 Subject: [PATCH 35/58] fix --- .../layer_infer/transformer_layer_infer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index af325c0de4..393fa1903f 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -17,7 +17,6 @@ def __init__(self, layer_num, network_config): super().__init__(layer_num, network_config) self.indexer = NsaInfer(layer_idx=self.layer_num_, network_config=self.network_config_) - self.topk_indices = None return def _get_qkv( @@ -33,7 +32,7 @@ def _get_qkv( ) q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) - self.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) + infer_state.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) @@ -71,12 +70,14 @@ def _context_attention_kernel( 
att_control = AttControl( nsa_prefill=True, nsa_prefill_dict={ - "topk_indices": self.topk_indices, + "topk_indices": infer_state.topk_indices, "softmax_scale": self.softmax_scale, "kv_lora_rank": self.kv_lora_rank, }, ) + del infer_state.topk_indices + mla_out = infer_state.prefill_att_state.prefill_att( q=q_all, k=infer_state.mem_manager.kv_buffer[self.layer_num_], @@ -100,7 +101,7 @@ def _token_attention_kernel( att_control = AttControl( nsa_decode=True, nsa_decode_dict={ - "topk_indices": self.topk_indices, + "topk_indices": infer_state.topk_indices, "nsa_cache_seqlens": infer_state.nsa_cache_seqlens, "nsa_cu_seqlens_k": infer_state.nsa_cu_seqlens_k, "softmax_scale": self.softmax_scale, @@ -109,11 +110,9 @@ def _token_attention_kernel( }, ) - # Create decode state and execute attention - nsa_backend = self._get_nsa_backend() - decode_state = nsa_backend.create_att_decode_state(infer_state) - decode_state.init_state() - o_tensor = decode_state.decode_att( + del infer_state.topk_indices + + o_tensor = infer_state.decode_att_state.decode_att( q=(q_nope, q_rope), k=infer_state.mem_manager.kv_buffer[self.layer_num_], v=None, From c196bcafb8b62f9cdecdfa1a4adfbf720c9c383c Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 01:40:53 +0000 Subject: [PATCH 36/58] fix --- .../models/deepseek3_2/layer_infer/transformer_layer_infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 393fa1903f..b773652e7c 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -32,7 +32,7 @@ def _get_qkv( ) q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) - infer_state.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight.indexer_layer_weight) + infer_state.topk_indices = 
self.indexer.get_indices(input, q, infer_state, layer_weight) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) From 3de599bf9348bad4f7be8fa727e33a0176274b68 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 01:51:41 +0000 Subject: [PATCH 37/58] fix --- .../layer_weights/transformer_layer_weight.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 0f4d6b13ae..86a887a259 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -36,20 +36,11 @@ def load_hf_weights(self, weights): """ for attr_name in dir(self): attr = getattr(self, attr_name, None) - if isinstance(attr, TransformerLayerWeight): - attr.load_hf_weights(weights) - elif isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2: + if isinstance(attr, MMWeightTpl) and len(attr.weight_names) >= 2: with self.lock: attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) - def verify_load(self): - for attr_name in dir(self): - attr = getattr(self, attr_name, None) - if isinstance(attr, TransformerLayerWeight): - attr.verify_load() - super().verify_load() - def get_quant_method(self, name): return self.quant_cfg.get_quant_method(self.layer_num_, name) From 2fb07282af6593f4e3aebe077dac80585320a2bb Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 04:12:45 +0000 Subject: [PATCH 38/58] fix --- .../basemodel/triton_kernel/gen_nsa_ks_ke.py | 76 +++++++++++++++++++ .../triton_kernel/test_gen_nsa_ks_ke.py | 61 +++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py create mode 100644 unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py diff 
--git a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py new file mode 100644 index 0000000000..d05cca2c1b --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py @@ -0,0 +1,76 @@ +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import Optional + + +@triton.jit +def _gen_nsa_ks_ke( + b_seq_len, + b_q_seq_len, + b_same_req_mark, + ks, + ke, + BLOCK_REQ: tl.constexpr, + BLOCK_SEQ_SPLIT: tl.constexpr, +): + cur_index = tl.program_id(0) + # 不处于边界mark的最后一个req不进行处理。 + req_mark = tl.load(b_same_req_mark + cur_index) + if req_mark == 0: + return + + off = tl.arange(0, BLOCK_REQ) + b_same_req_mark = tl.load(b_same_req_mark + off, off < cur_index, other=0) + b_seq_len_data = tl.load(b_seq_len + off, (off < cur_index) & (b_same_req_mark != 0), other=0) + pre_sum_seq_len = tl.sum(b_seq_len_data) + + # 兼容 prefill 和 decode 的情况, decode 可能存在 mtp 的情况,各个请求会共享一个req对象,其处理比较特殊 + q_seq_len = tl.load(b_q_seq_len + cur_index) + req_mark - 1 + cur_total_len = tl.load(b_seq_len + cur_index) + + b_q_seq_len_data = tl.load(b_q_seq_len + off, (off < (cur_index - req_mark + 1)), other=0) + store_start_index = tl.sum(b_q_seq_len_data) + + for block_index in range(tl.cdiv(q_seq_len, BLOCK_SEQ_SPLIT)): + block_start = block_index * BLOCK_SEQ_SPLIT + block_end = min(q_seq_len, (block_index + 1) * BLOCK_SEQ_SPLIT) + ks_data = tl.zeros((BLOCK_SEQ_SPLIT,), dtype=tl.int32) + ke_data = (cur_total_len - q_seq_len) + tl.arange(0, BLOCK_SEQ_SPLIT) + + tl.store( + ks + store_start_index + block_start + tl.arange(0, BLOCK_SEQ_SPLIT), + ks_data + pre_sum_seq_len, + mask=block_start + tl.arange(0, BLOCK_SEQ_SPLIT) < block_end, + ) + tl.store( + ke + store_start_index + block_start + tl.arange(0, BLOCK_SEQ_SPLIT), + ke_data + pre_sum_seq_len, + mask=block_start + tl.arange(0, BLOCK_SEQ_SPLIT) < block_end, + ) + + return + + +@torch.no_grad() +def gen_nsa_ks_ke( + 
b_seq_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_same_req_mark: torch.Tensor, + q_token_num: int, +): + batch_size = b_seq_len.shape[0] + ks = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) + ke = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) + + _gen_nsa_ks_ke[(batch_size,)]( + b_seq_len=b_seq_len, + b_q_seq_len=b_q_seq_len, + b_same_req_mark=b_same_req_mark, + ks=ks, + ke=ke, + BLOCK_REQ=triton.next_power_of_2(batch_size), + BLOCK_SEQ_SPLIT=256, + ) + return ks, ke diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py b/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py new file mode 100644 index 0000000000..29adad6818 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py @@ -0,0 +1,61 @@ +import torch +import pytest +from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + +def test_gen_nsa_ks_ke_basic(): + """Test basic functionality of gen_nsa_ks_ke with simple inputs.""" + # Setup test data + b_seq_len = torch.tensor( + [ + 10, + ], + dtype=torch.int32, + device="cuda", + ) + b_q_seq_len = torch.tensor( + [ + 5, + ], + dtype=torch.int32, + device="cuda", + ) + b_same_req_mark = torch.tensor( + [ + 1, + ], + dtype=torch.int32, + device="cuda", + ) + q_token_num = b_q_seq_len.sum().item() + + ks, ke = gen_nsa_ks_ke(b_seq_len, b_q_seq_len, b_same_req_mark, q_token_num) + + assert torch.equal(ks, torch.tensor([0, 0, 0, 0, 0], dtype=torch.int32, device="cuda")) + assert torch.equal(ke, torch.tensor([5, 6, 7, 8, 9], dtype=torch.int32, device="cuda")) + + +def test_gen_nsa_ks_ke_batch(): + b_seq_len = torch.tensor([10, 11], dtype=torch.int32, device="cuda") + b_q_seq_len = torch.tensor([1, 1], dtype=torch.int32, device="cuda") + b_same_req_mark = torch.tensor([0, 2], dtype=torch.int32, device="cuda") + q_token_num = b_q_seq_len.sum().item() + + ks, ke = gen_nsa_ks_ke(b_seq_len, b_q_seq_len, b_same_req_mark, 
q_token_num) + + assert torch.equal( + ks, + torch.tensor( + [ + 0, + 0, + ], + dtype=torch.int32, + device="cuda", + ), + ) + assert torch.equal(ke, torch.tensor([9, 10], dtype=torch.int32, device="cuda")) + + +if __name__ == "__main__": + pytest.main() From 51d7e3750bf4334ed630a344160623dd6f771090 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 04:15:47 +0000 Subject: [PATCH 39/58] fix --- .../deepseek3_2/layer_infer/transformer_layer_infer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index b773652e7c..eb2c520330 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -9,6 +9,7 @@ from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks +from lightllm.utils.envs_utils import get_env_start_args class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): @@ -161,6 +162,10 @@ def get_indices( lengths = infer_state.lengths page_table_1 = infer_state.page_table_size_1 + if infer_state.is_prefill: + mtp_step = 0 + else: + mtp_step = get_env_start_args().mtp_step # Use efficient Triton kernel to extract FP8 keys and scales from buffer k_fp8_, k_scale_ = extract_indexer_ks( I_buffer=infer_state.mem_manager.indexer_ks_buffer.kv_buffer[self.layer_idx_], @@ -169,7 +174,7 @@ def get_indices( req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, out_token_num=infer_state.b_seq_len.shape[0] * infer_state.max_kv_seq_len, max_kv_seq_len=infer_state.max_kv_seq_len, - mtp_step=0, + mtp_step=mtp_step, ) # Get actual sequence length from q (which comes from q_lora) From 
2f6b3a5b06360b4d6fead2b458e9e24d82070b3e Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 04:48:02 +0000 Subject: [PATCH 40/58] fix --- .../basemodel/triton_kernel/gen_nsa_ks_ke.py | 31 ++++++++++++++++++- .../triton_kernel/test_gen_nsa_ks_ke.py | 19 +++++++++--- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py index d05cca2c1b..ec30f176e5 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py +++ b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py @@ -57,12 +57,13 @@ def _gen_nsa_ks_ke( def gen_nsa_ks_ke( b_seq_len: torch.Tensor, b_q_seq_len: torch.Tensor, - b_same_req_mark: torch.Tensor, + b_req_idx: torch.Tensor, q_token_num: int, ): batch_size = b_seq_len.shape[0] ks = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) ke = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) + b_same_req_mark = gen_same_req_mark(b_req_idx) _gen_nsa_ks_ke[(batch_size,)]( b_seq_len=b_seq_len, @@ -74,3 +75,31 @@ def gen_nsa_ks_ke( BLOCK_SEQ_SPLIT=256, ) return ks, ke + + +@triton.jit +def _gen_same_req_mark(b_req_idx, b_same_req_mark, BLOCK_SIZE: tl.constexpr): + cur_index = tl.program_id(0) + cur_req_idx = tl.load(b_req_idx + cur_index) + off = tl.arange(0, BLOCK_SIZE) + pre_req_idxs = tl.load(b_req_idx + off, (off < cur_index) & (off < tl.num_programs(0)), other=-1) + after_req_idxs = tl.load(b_req_idx + off, (off > cur_index) & (off < tl.num_programs(0)), other=-1) + + pre_idx_count = tl.sum(pre_req_idxs == cur_req_idx) + after_idx_count = tl.sum(after_req_idxs == cur_req_idx) + + has_mark = tl.where(after_idx_count == 0, 1, 0) + for _ in range(has_mark): + tl.store(b_same_req_mark + cur_index, pre_idx_count + 1) + return + + +@torch.no_grad() +def gen_same_req_mark(b_req_idx: torch.Tensor): + batch_size = b_req_idx.shape[0] + b_same_req_mark = torch.empty((batch_size,), 
dtype=torch.int32, device=b_req_idx.device) + b_same_req_mark.fill_(0) + _gen_same_req_mark[(batch_size,)]( + b_req_idx=b_req_idx, b_same_req_mark=b_same_req_mark, BLOCK_SIZE=triton.next_power_of_2(batch_size) + ) + return b_same_req_mark diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py b/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py index 29adad6818..8fcdc0e51a 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py +++ b/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py @@ -1,6 +1,6 @@ import torch import pytest -from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke +from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke, gen_same_req_mark def test_gen_nsa_ks_ke_basic(): @@ -20,7 +20,7 @@ def test_gen_nsa_ks_ke_basic(): dtype=torch.int32, device="cuda", ) - b_same_req_mark = torch.tensor( + b_req_idx = torch.tensor( [ 1, ], @@ -29,7 +29,7 @@ def test_gen_nsa_ks_ke_basic(): ) q_token_num = b_q_seq_len.sum().item() - ks, ke = gen_nsa_ks_ke(b_seq_len, b_q_seq_len, b_same_req_mark, q_token_num) + ks, ke = gen_nsa_ks_ke(b_seq_len, b_q_seq_len, b_req_idx, q_token_num) assert torch.equal(ks, torch.tensor([0, 0, 0, 0, 0], dtype=torch.int32, device="cuda")) assert torch.equal(ke, torch.tensor([5, 6, 7, 8, 9], dtype=torch.int32, device="cuda")) @@ -38,10 +38,10 @@ def test_gen_nsa_ks_ke_basic(): def test_gen_nsa_ks_ke_batch(): b_seq_len = torch.tensor([10, 11], dtype=torch.int32, device="cuda") b_q_seq_len = torch.tensor([1, 1], dtype=torch.int32, device="cuda") - b_same_req_mark = torch.tensor([0, 2], dtype=torch.int32, device="cuda") + b_req_idx = torch.tensor([1, 1], dtype=torch.int32, device="cuda") q_token_num = b_q_seq_len.sum().item() - ks, ke = gen_nsa_ks_ke(b_seq_len, b_q_seq_len, b_same_req_mark, q_token_num) + ks, ke = gen_nsa_ks_ke(b_seq_len, b_q_seq_len, b_req_idx, q_token_num) assert torch.equal( ks, @@ -57,5 +57,14 @@ def 
test_gen_nsa_ks_ke_batch(): assert torch.equal(ke, torch.tensor([9, 10], dtype=torch.int32, device="cuda")) +def test_gen_same_req_mark(): + b_req_idx = torch.tensor([0, 0, 1, 1, 1, 2], dtype=torch.int32, device="cuda") + expected_same_req_mark = torch.tensor([0, 2, 0, 0, 3, 1], dtype=torch.int32, device="cuda") + + same_req_mark = gen_same_req_mark(b_req_idx) + + assert torch.equal(same_req_mark, expected_same_req_mark) + + if __name__ == "__main__": pytest.main() From a37dc86fcf005f8ecea616ef66c1fae59aff2f61 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 06:06:26 +0000 Subject: [PATCH 41/58] fix --- .../models/deepseek3_2/_del_infer_struct.py | 209 ------------------ .../layer_infer/transformer_layer_infer.py | 11 +- .../triton_kernel/topk_index_to_mem_index.py | 47 ++++ .../test_topk_index_to_mem_index.py | 24 ++ 4 files changed, 77 insertions(+), 214 deletions(-) delete mode 100644 lightllm/models/deepseek3_2/_del_infer_struct.py create mode 100644 lightllm/models/deepseek3_2/triton_kernel/topk_index_to_mem_index.py create mode 100644 unit_tests/models/deepseek3_2/triton_kernel/test_topk_index_to_mem_index.py diff --git a/lightllm/models/deepseek3_2/_del_infer_struct.py b/lightllm/models/deepseek3_2/_del_infer_struct.py deleted file mode 100644 index 779c2fc2d2..0000000000 --- a/lightllm/models/deepseek3_2/_del_infer_struct.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch -import weakref -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.models.deepseek3_2.mem_manager import Deepseek3_2MemoryManager - - -class Deepseek3_2InferStateInfo(Deepseek2InferStateInfo): - _shared_nsa_buffers = None - - def __init__(self): - super().__init__() - self.lengths = None - self.page_table_size_1 = None - self.ks = None - self.ke = None - self.nsa_cu_seqlens_k = None - self.index_topk = 2048 - return - - @classmethod - def get_nsa_buffers(cls, graph_max_batch_size: int, max_seq_len: int): - """Get or create pre-allocated 
buffers for CUDA graph execution""" - if cls._shared_nsa_buffers is None: - max_total_q_tokens = graph_max_batch_size * max_seq_len - max_total_tokens = graph_max_batch_size * max_seq_len - - cls._shared_nsa_buffers = [ - { - "ks": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), - "ke": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), - "lengths": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), - "page_table_size_1": torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device="cuda"), - "req_all_mem_index": torch.empty(max_total_tokens, dtype=torch.int64, device="cuda"), - "nsa_cache_seqlens": torch.empty(graph_max_batch_size, dtype=torch.int32, device="cuda"), - "nsa_cu_seqlens_k": torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device="cuda"), - }, - { - "ks": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), - "ke": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), - "lengths": torch.empty(max_total_q_tokens, dtype=torch.int, device="cuda"), - "page_table_size_1": torch.empty(graph_max_batch_size, max_seq_len, dtype=torch.int, device="cuda"), - "req_all_mem_index": torch.empty(max_total_tokens, dtype=torch.int64, device="cuda"), - "nsa_cache_seqlens": torch.empty(graph_max_batch_size, dtype=torch.int32, device="cuda"), - "nsa_cu_seqlens_k": torch.empty(graph_max_batch_size + 1, dtype=torch.int32, device="cuda"), - }, - ] - return cls._shared_nsa_buffers - - def _check_use_cuda_graph_buffers(self): - if hasattr(self, "_model_ref"): - model = self._model_ref() - if ( - model is not None - and hasattr(model, "graph_max_batch_size") - and hasattr(model, "graph_max_len_in_batch") - and self.batch_size <= model.graph_max_batch_size - and self.max_kv_seq_len <= model.graph_max_len_in_batch - ): - return True - return False - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - - self._model_ref = weakref.ref(model) - - assert 
isinstance(self.mem_manager, Deepseek3_2MemoryManager) - self.indexer_ks_buffer = self.mem_manager.indexer_ks_buffer - - if self.is_prefill: - self._init_nsa_indexing_prefill() - else: - if self.b_ready_cache_len is None: - self.b_ready_cache_len = torch.zeros_like(self.b_seq_len) - - use_cuda_graph_buffers = self._check_use_cuda_graph_buffers() - buffer = None - - if use_cuda_graph_buffers: - buffers = self.get_nsa_buffers(model.graph_max_batch_size, model.graph_max_len_in_batch) - buffer = buffers[self.microbatch_index] - self.nsa_cache_seqlens = buffer["nsa_cache_seqlens"][: self.batch_size] - self.nsa_cu_seqlens_k = buffer["nsa_cu_seqlens_k"][: self.batch_size + 1] - else: - self.nsa_cache_seqlens = torch.empty(self.batch_size, dtype=torch.int32, device="cuda") - self.nsa_cu_seqlens_k = torch.empty(self.batch_size + 1, dtype=torch.int32, device="cuda") - - self.nsa_cache_seqlens.copy_(self.b_kv_seq_len.clamp(max=self.index_topk)) - assert self.nsa_cache_seqlens.dtype == torch.int32 - - torch.cumsum(self.nsa_cache_seqlens, dim=0, dtype=torch.int32, out=self.nsa_cu_seqlens_k[1:]) - self.nsa_cu_seqlens_k[0] = 0 - - self._init_nsa_indexing_decode(use_cuda_graph_buffers, buffer) - - def _init_nsa_indexing_decode(self, use_cuda_graph_buffers, buffer): - """Optimized NSA indexing for decode: b_q_seq_len=1 per request. - - In decode, each request generates exactly 1 token, so: - - total_q_len = batch_size (no .item() needed) - - ks[i] = cumsum_offset[i], ke[i] = cumsum_offset[i] + 1 - - lengths[i] = b_seq_len[i] - - No repeat_interleave, no token_in_req math needed. - """ - b_seq_len = self.b_seq_len - b_req_idx = self.b_req_idx - num_seq = self.batch_size - - # Cumulative seq_len offsets for ks/ke: [0, s0, s0+s1, ...] 
- cum_seq = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - - if use_cuda_graph_buffers: - model = self._model_ref() - max_seq_len = model.graph_max_len_in_batch - - # ks, ke, lengths — write directly into buffer slices - buf_ks = buffer["ks"][:num_seq] - buf_ke = buffer["ke"][:num_seq] - buf_lengths = buffer["lengths"][:num_seq] - - # ks[0] = 0, ks[i] = cum_seq[i-1] - buf_ks[0] = 0 - if num_seq > 1: - buf_ks[1:].copy_(cum_seq[: num_seq - 1]) - # ke = ks + 1 - torch.add(buf_ks, 1, out=buf_ke) - # lengths = b_seq_len - buf_lengths.copy_(b_seq_len.int()) - - self.ks = buf_ks - self.ke = buf_ke - self.lengths = buf_lengths - - # page_table: zero buffer slice, then fill valid entries - page_table = buffer["page_table_size_1"][:num_seq, :max_seq_len] - page_table.zero_() - all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] - seq_range = torch.arange(max_seq_len, device=b_seq_len.device) - valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) - page_table[valid_mask] = all_rows[valid_mask].int() - self.page_table_size_1 = page_table - - # req_all_mem_index: use padded [num_seq * max_seq_len] layout - # Downstream uses ks/ke masking so padded entries are safe - max_total_seq = num_seq * max_seq_len - buf_mem = buffer["req_all_mem_index"][:max_total_seq] - buf_mem.copy_(all_rows.reshape(-1)) - self.req_all_mem_index = buf_mem - else: - # Non-CUDA-graph decode: simplified formulas, fresh tensors - max_seq_len = b_seq_len.max().item() - - # ks/ke/lengths - seq_offsets = torch.empty_like(cum_seq) - seq_offsets[0] = 0 - if num_seq > 1: - seq_offsets[1:] = cum_seq[:-1] - - self.ks = seq_offsets - self.ke = (seq_offsets + 1).int() - self.lengths = b_seq_len.int() - - # page_table and req_all_mem_index - all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] - seq_range = torch.arange(max_seq_len, device=b_seq_len.device) - valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) - - page_table = torch.zeros((num_seq, 
max_seq_len), dtype=torch.int, device=b_seq_len.device) - page_table[valid_mask] = all_rows[valid_mask].int() - self.page_table_size_1 = page_table - - self.req_all_mem_index = all_rows[valid_mask] - - def _init_nsa_indexing_prefill(self): - """NSA indexing for prefill: variable q lengths, generic vectorized path.""" - b_seq_len = self.b_seq_len - b_q_seq_len = self.b_q_seq_len - b_req_idx = self.b_req_idx - num_seq = b_req_idx.shape[0] - device = b_seq_len.device - - max_seq_len = b_seq_len.max().item() - total_q_len = b_q_seq_len.sum().item() - - # page_table_size_1 and req_all_mem_index - all_rows = self.req_manager.req_to_token_indexs[b_req_idx, :max_seq_len] - seq_range = torch.arange(max_seq_len, device=device) - valid_mask = seq_range.unsqueeze(0) < b_seq_len.unsqueeze(1) - - page_table = torch.zeros((num_seq, max_seq_len), dtype=torch.int, device=device) - page_table[valid_mask] = all_rows[valid_mask].int() - self.page_table_size_1 = page_table - self.req_all_mem_index = all_rows[valid_mask] - - # ks, ke, lengths — generic vectorized for variable q lengths - cum_seq = torch.cumsum(b_seq_len, dim=0) - seq_offsets = torch.zeros_like(cum_seq) - seq_offsets[1:] = cum_seq[:-1] - - req_indices = torch.repeat_interleave(torch.arange(num_seq, device=device), b_q_seq_len) - - cum_q = torch.cumsum(b_q_seq_len, dim=0) - q_offsets = torch.zeros_like(cum_q) - q_offsets[1:] = cum_q[:-1] - token_in_req = torch.arange(total_q_len, device=device) - q_offsets[req_indices] - - self.ks = seq_offsets[req_indices].int() - self.ke = (seq_offsets[req_indices] + token_in_req + 1).int() - self.lengths = (b_seq_len[req_indices] - b_q_seq_len[req_indices] + token_in_req + 1).int() diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index eb2c520330..43bf45f5c0 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ 
b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -160,7 +160,6 @@ def get_indices( ks = infer_state.ks ke = infer_state.ke lengths = infer_state.lengths - page_table_1 = infer_state.page_table_size_1 if infer_state.is_prefill: mtp_step = 0 @@ -193,15 +192,17 @@ def get_indices( logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) - from sgl_kernel import fast_topk_transform_fused + from sgl_kernel import fast_topk_v2 - return fast_topk_transform_fused( + b_topk_index = fast_topk_v2( score=logits, lengths=lengths, - page_table_size_1=page_table_1, - cu_seqlens_q=infer_state.b1_cu_q_seq_len, topk=self.index_topk, + row_starts=ke, ) + # 将 topk index 转化为 mem index + + return b_topk_index @staticmethod def _rotate_activation(x: torch.Tensor) -> torch.Tensor: diff --git a/lightllm/models/deepseek3_2/triton_kernel/topk_index_to_mem_index.py b/lightllm/models/deepseek3_2/triton_kernel/topk_index_to_mem_index.py new file mode 100644 index 0000000000..12786c6619 --- /dev/null +++ b/lightllm/models/deepseek3_2/triton_kernel/topk_index_to_mem_index.py @@ -0,0 +1,47 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _trans_topk_index_to_mem_index( + topk_index, + topk_index_stride_b, + topk_index_stride_k, + ragged_mem_index, + topk_mem_index, + topk_mem_index_stride_b, + topk_mem_index_stride_k, + BLOCK_DMODEL: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_DMODEL) + topk_index_ptrs = topk_index + cur_index * topk_index_stride_b + offs_d * topk_index_stride_k + topk_indices = tl.load(topk_index_ptrs) + + dest_mem_index = ragged_mem_index + topk_indices + mem_index = tl.load(dest_mem_index, mask=topk_indices != -1, other=-1) + tl.store(topk_mem_index + cur_index * topk_mem_index_stride_b + offs_d * topk_mem_index_stride_k, mem_index) + + +@torch.no_grad() +def trans_topk_index_to_mem_index(topk_index: torch.Tensor, ragged_mem_index: torch.Tensor): + 
assert topk_index.shape[1] == 2048, f"Expected topk_index shape[1]=2048, got {topk_index.shape[1]}" + + grid = (topk_index.shape[0],) + + topk_mem_index = torch.empty_like(topk_index) + + _trans_topk_index_to_mem_index[grid]( + topk_index=topk_index, + topk_index_stride_b=topk_index.stride(0), + topk_index_stride_k=topk_index.stride(1), + ragged_mem_index=ragged_mem_index, + topk_mem_index=topk_mem_index, + topk_mem_index_stride_b=topk_mem_index.stride(0), + topk_mem_index_stride_k=topk_mem_index.stride(1), + BLOCK_DMODEL=2048, + num_warps=8, + ) + return topk_mem_index diff --git a/unit_tests/models/deepseek3_2/triton_kernel/test_topk_index_to_mem_index.py b/unit_tests/models/deepseek3_2/triton_kernel/test_topk_index_to_mem_index.py new file mode 100644 index 0000000000..5e3aaf38d0 --- /dev/null +++ b/unit_tests/models/deepseek3_2/triton_kernel/test_topk_index_to_mem_index.py @@ -0,0 +1,24 @@ +import torch +import pytest +from lightllm.models.deepseek3_2.triton_kernel.topk_index_to_mem_index import trans_topk_index_to_mem_index + + +def test_trans_topk_index_to_mem_index(): + """Test trans_topk_index_to_mem_index converts topk indices to memory indices correctly.""" + batch_size = 1 + topk = 2048 + + # Create topk_index tensor with some valid indices and some -1 (padding) + topk_index = torch.zeros((batch_size, topk), dtype=torch.int32, device="cuda") + topk_index[:, 0:2048] = torch.arange(0, 2048, dtype=torch.int32, device="cuda") + + # Create ragged_mem_index lookup table + ragged_mem_index = torch.arange(0, 2048, dtype=torch.int32, device="cuda") + 10 + + topk_mem_index = trans_topk_index_to_mem_index(topk_index, ragged_mem_index) + + assert torch.equal(topk_mem_index, (torch.arange(0, 2048, dtype=torch.int32, device="cuda") + 10).view(1, -1)) + + +if __name__ == "__main__": + pytest.main() From adb7cad6b4a8bc0186e3259a16ce4e539d8ef75b Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 06:40:25 +0000 Subject: [PATCH 42/58] fix --- 
.../basemodel/triton_kernel/gen_nsa_ks_ke.py | 38 ++++++++++++++++++- .../triton_kernel/test_gen_nsa_ks_ke.py | 21 +++++++++- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py index ec30f176e5..13a4e1deb3 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py +++ b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py @@ -8,10 +8,16 @@ @triton.jit def _gen_nsa_ks_ke( b_seq_len, + b_req_idx, b_q_seq_len, b_same_req_mark, ks, ke, + lengths, + req_to_token_index, + strided_req_to_token_index_b, + strided_req_to_token_index_s, + ragged_mem_index, BLOCK_REQ: tl.constexpr, BLOCK_SEQ_SPLIT: tl.constexpr, ): @@ -29,6 +35,7 @@ def _gen_nsa_ks_ke( # 兼容 prefill 和 decode 的情况, decode 可能存在 mtp 的情况,各个请求会共享一个req对象,其处理比较特殊 q_seq_len = tl.load(b_q_seq_len + cur_index) + req_mark - 1 cur_total_len = tl.load(b_seq_len + cur_index) + cur_req_idx = tl.load(b_req_idx + cur_index) b_q_seq_len_data = tl.load(b_q_seq_len + off, (off < (cur_index - req_mark + 1)), other=0) store_start_index = tl.sum(b_q_seq_len_data) @@ -49,7 +56,27 @@ def _gen_nsa_ks_ke( ke_data + pre_sum_seq_len, mask=block_start + tl.arange(0, BLOCK_SEQ_SPLIT) < block_end, ) + tl.store( + lengths + store_start_index + block_start + tl.arange(0, BLOCK_SEQ_SPLIT), + ke_data - ks_data + 1, + mask=block_start + tl.arange(0, BLOCK_SEQ_SPLIT) < block_end, + ) + for block_index in range(tl.cdiv(cur_total_len, BLOCK_SEQ_SPLIT)): + block_start = block_index * BLOCK_SEQ_SPLIT + block_end = min(cur_total_len, (block_index + 1) * BLOCK_SEQ_SPLIT) + mask = block_start + tl.arange(0, BLOCK_SEQ_SPLIT) < block_end + + src_mem_index_ptr = ( + req_to_token_index + + strided_req_to_token_index_b * cur_req_idx + + block_start + + tl.arange(0, BLOCK_SEQ_SPLIT) + ) + src_mem_index = tl.load(src_mem_index_ptr, mask=mask, other=-1) + tl.store( + ragged_mem_index + pre_sum_seq_len + block_start + 
tl.arange(0, BLOCK_SEQ_SPLIT), src_mem_index, mask=mask + ) return @@ -58,23 +85,32 @@ def gen_nsa_ks_ke( b_seq_len: torch.Tensor, b_q_seq_len: torch.Tensor, b_req_idx: torch.Tensor, + req_to_token_index: torch.Tensor, q_token_num: int, + ragged_mem_index: torch.Tensor, ): batch_size = b_seq_len.shape[0] ks = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) ke = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) + lengths = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) b_same_req_mark = gen_same_req_mark(b_req_idx) _gen_nsa_ks_ke[(batch_size,)]( b_seq_len=b_seq_len, + b_req_idx=b_req_idx, b_q_seq_len=b_q_seq_len, b_same_req_mark=b_same_req_mark, ks=ks, ke=ke, + lengths=lengths, + req_to_token_index=req_to_token_index, + strided_req_to_token_index_b=req_to_token_index.stride(0), + strided_req_to_token_index_s=req_to_token_index.stride(1), + ragged_mem_index=ragged_mem_index, BLOCK_REQ=triton.next_power_of_2(batch_size), BLOCK_SEQ_SPLIT=256, ) - return ks, ke + return ks, ke, lengths @triton.jit diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py b/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py index 8fcdc0e51a..6567479c25 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py +++ b/unit_tests/common/basemodel/triton_kernel/test_gen_nsa_ks_ke.py @@ -29,10 +29,20 @@ def test_gen_nsa_ks_ke_basic(): ) q_token_num = b_q_seq_len.sum().item() - ks, ke = gen_nsa_ks_ke(b_seq_len, b_q_seq_len, b_req_idx, q_token_num) + req_to_token_index = torch.arange(0, 1000).cuda().view(100, -1) + ragged_mem_index = torch.empty_like(req_to_token_index.view(-1)) + + ks, ke, lengths = gen_nsa_ks_ke( + b_seq_len, b_q_seq_len, b_req_idx, req_to_token_index, q_token_num, ragged_mem_index + ) assert torch.equal(ks, torch.tensor([0, 0, 0, 0, 0], dtype=torch.int32, device="cuda")) assert torch.equal(ke, torch.tensor([5, 6, 7, 8, 9], dtype=torch.int32, 
device="cuda")) + assert torch.equal(lengths, ke - ks + 1) + assert torch.equal( + ragged_mem_index[0:10], torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=torch.int32, device="cuda") + ) + return def test_gen_nsa_ks_ke_batch(): @@ -41,7 +51,12 @@ def test_gen_nsa_ks_ke_batch(): b_req_idx = torch.tensor([1, 1], dtype=torch.int32, device="cuda") q_token_num = b_q_seq_len.sum().item() - ks, ke = gen_nsa_ks_ke(b_seq_len, b_q_seq_len, b_req_idx, q_token_num) + req_to_token_index = torch.arange(0, 1000).cuda().view(10, -1) + ragged_mem_index = torch.empty_like(req_to_token_index.view(-1)) + + ks, ke, lengths = gen_nsa_ks_ke( + b_seq_len, b_q_seq_len, b_req_idx, req_to_token_index, q_token_num, ragged_mem_index + ) assert torch.equal( ks, @@ -55,6 +70,8 @@ def test_gen_nsa_ks_ke_batch(): ), ) assert torch.equal(ke, torch.tensor([9, 10], dtype=torch.int32, device="cuda")) + assert torch.equal(lengths, ke - ks + 1) + assert torch.equal(ragged_mem_index[0:11], torch.arange(100, 100 + 11).cuda()) def test_gen_same_req_mark(): From ab0ce7a3f83468de1593ba0a06465c05a2e89d01 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 06:47:57 +0000 Subject: [PATCH 43/58] fix --- lightllm/models/deepseek3_2/model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index 09b6be4dbf..0fa0ea6c8f 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -3,7 +3,6 @@ from lightllm.models.deepseek2.model import Deepseek2TpPartModel from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.models.deepseek3_2.layer_infer.transformer_layer_infer import Deepseek3_2TransformerLayerInfer -from lightllm.models.deepseek3_2.infer_struct import Deepseek3_2InferStateInfo from lightllm.common.basemodel.attention import get_nsa_prefill_att_backend_class, get_nsa_decode_att_backend_class @@ -106,9 
+105,6 @@ class Deepseek3_2TpPartModel(Deepseek2TpPartModel): # infer class transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer - # infer state class - infer_state_class = Deepseek3_2InferStateInfo - def _init_att_backend(self): self.prefill_att_backend = get_nsa_prefill_att_backend_class(index=0)(model=self) self.decode_att_backend = get_nsa_decode_att_backend_class(index=0)(model=self) From e069e9a5296bfb7c24fb0db2a7a76f14450b7491 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 06:51:08 +0000 Subject: [PATCH 44/58] fix --- .../layer_infer/transformer_layer_infer.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 43bf45f5c0..2a42aa5659 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,8 +1,7 @@ import torch - +from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight -from lightllm.models.deepseek3_2._del_infer_struct import Deepseek3_2InferStateInfo from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.basemodel.attention.base_att import AttControl @@ -23,7 +22,7 @@ def __init__(self, layer_num, network_config): def _get_qkv( self, input: torch.Tensor, - infer_state: Deepseek3_2InferStateInfo, + infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) @@ -58,7 +57,7 @@ def _context_attention_kernel( self, q: 
torch.Tensor, kv, - infer_state: Deepseek3_2InferStateInfo, + infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ) -> torch.Tensor: @@ -90,7 +89,7 @@ def _context_attention_kernel( def _token_attention_kernel( self, q, - infer_state: Deepseek3_2InferStateInfo, + infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, out=None, ): @@ -142,7 +141,7 @@ def get_indices( self, hidden_states: torch.Tensor, q_lora: torch.Tensor, - infer_state: Deepseek3_2InferStateInfo, + infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, ) -> torch.Tensor: @@ -217,7 +216,7 @@ def _get_q_k_bf16( self, hidden_states: torch.Tensor, q_lora: torch.Tensor, - infer_state: Deepseek3_2InferStateInfo, + infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek3_2TransformerLayerWeight, ): q = layer_weight.wq_b_proj_.mm(q_lora).view(-1, self.index_n_heads, self.index_head_dim) From abc836daa0bddfbd9eb96b42f37a11a5e4ab3d32 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 07:38:17 +0000 Subject: [PATCH 45/58] fix --- .../attention/nsa/flashmla_sparse.py | 83 ++++++++++++++++--- 1 file changed, 72 insertions(+), 11 deletions(-) diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index 2c347ed32b..b7ad12e1d7 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -15,6 +15,14 @@ class NsaFlashMlaSparseAttBackend(BaseAttBackend): def __init__(self, model): super().__init__(model=model) + self.ragged_mem_buffers = [ + torch.empty( + model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + torch.empty( + model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + ] def create_att_prefill_state(self, infer_state: 
"InferStateInfo") -> "NsaFlashMlaSparsePrefillAttState": return NsaFlashMlaSparsePrefillAttState(backend=self, infer_state=infer_state) @@ -27,12 +35,29 @@ def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMla class NsaFlashMlaSparsePrefillAttState(BasePrefillAttState): """Prefill attention state for NSA using flash_mla_sparse_fwd.""" - cu_seqlens_q: torch.Tensor = None - cu_seqlens_k: torch.Tensor = None + ks: torch.Tensor = None + ke: torch.Tensor = None + lengths: torch.Tensor = None + ragged_mem_index: torch.Tensor = None def init_state(self): - self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + self.backend: NsaFlashMlaSparseAttBackend = self.backend + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num, + ragged_mem_index=self.ragged_mem_index, + ) + return def prefill_att( self, @@ -77,11 +102,49 @@ def _nsa_prefill_att( class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState): cu_seqlens_q: torch.Tensor = None - cu_seqlens_k: torch.Tensor = None + ks: torch.Tensor = None + ke: torch.Tensor = None + length: torch.Tensor = None + ragged_mem_index: torch.Tensor = None + nsa_cache_seqlens: torch.Tensor = None + nsa_cu_seqlens_k_new: torch.Tensor = None def init_state(self): self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + self.backend: NsaFlashMlaSparseAttBackend = self.backend + model = 
self.backend.model + use_cuda_graph = ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ) + + if use_cuda_graph: + self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index] + else: + self.ragged_mem_index = torch.empty( + self.infer_state.total_token_num, + dtype=torch.int32, + device=get_current_device_id(), + ) + + from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke + + self.ks, self.ke, self.lengths = gen_nsa_ks_ke( + b_seq_len=self.infer_state.b_seq_len, + b_q_seq_len=self.infer_state.b_q_seq_len, + b_req_idx=self.infer_state.b_req_idx, + req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, + q_token_num=self.infer_state.b_seq_len.shape[0], + ragged_mem_index=self.ragged_mem_index, + ) + self.nsa_cache_seqlens = torch.minimum( + torch.full(size=(self.infer_state.batch_size,), value=2048), self.infer_state.b_seq_len + ) + padded_seq_lens = torch.zeros(size=(self.nsa_cache_seqlens.shape[0] + 1,), dtype=torch.int32, device="cuda") + # 进行 cumsum 操作 + padded_seq_lens[1:].copy_(self.nsa_cache_seqlens, non_blocking=True) + self.nsa_cu_seqlens_k_new = padded_seq_lens.cumsum(dim=0, dtype=torch.int32) def decode_att( self, @@ -106,8 +169,6 @@ def _nsa_decode_att( nsa_dict = att_control.nsa_decode_dict topk_indices = nsa_dict["topk_indices"] - nsa_cache_seqlens = nsa_dict["nsa_cache_seqlens"] - nsa_cu_seqlens_k = nsa_dict["nsa_cu_seqlens_k"] softmax_scale = nsa_dict["softmax_scale"] kv_lora_rank = nsa_dict["kv_lora_rank"] qk_rope_head_dim = nsa_dict["qk_rope_head_dim"] @@ -124,9 +185,9 @@ def _nsa_decode_att( v_cache=kv_nope, qv=q_nope, page_table=topk_indices, - cache_seqlens=nsa_cache_seqlens, - cu_seqlens_q=self.cu_seqlens_q, - cu_seqlens_k_new=nsa_cu_seqlens_k, + cache_seqlens=self.nsa_cache_seqlens, + cu_seqlens_q=self.infer_state.b1_cu_q_seq_len, + cu_seqlens_k_new=self.nsa_cu_seqlens_k_new, 
max_seqlen_q=self.infer_state.max_q_seq_len, softmax_scale=softmax_scale, causal=True, From 63eb4e1b5e1ac30158d38b063994986b03149bc4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 07:49:03 +0000 Subject: [PATCH 46/58] fix --- .../layer_infer/transformer_layer_infer.py | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 2a42aa5659..011cc341cd 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -1,10 +1,12 @@ import torch +from typing import Union from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.common.basemodel.attention.nsa import NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks @@ -142,6 +144,7 @@ def get_indices( hidden_states: torch.Tensor, q_lora: torch.Tensor, infer_state: Deepseek2InferStateInfo, + att_state: Union[NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState], layer_weight: Deepseek3_2TransformerLayerWeight, ) -> torch.Tensor: @@ -156,9 +159,9 @@ def get_indices( weights = 
layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale weights = weights.unsqueeze(-1) * q_scale - ks = infer_state.ks - ke = infer_state.ke - lengths = infer_state.lengths + ks = att_state.ks + ke = att_state.ke + lengths = att_state.lengths if infer_state.is_prefill: mtp_step = 0 @@ -175,18 +178,6 @@ def get_indices( mtp_step=mtp_step, ) - # Get actual sequence length from q (which comes from q_lora) - # This may differ from ks.shape[0] during certain operations - actual_seq_len = q.shape[0] - - # ks, ke, lengths, and weights should all match actual_seq_len - # Slice them if they don't match - if ks.shape[0] != actual_seq_len: - ks = ks[:actual_seq_len] - ke = ke[:actual_seq_len] - lengths = lengths[:actual_seq_len] - weights = weights[:actual_seq_len] - import deep_gemm logits = deep_gemm.fp8_mqa_logits(q_fp8, (k_fp8_, k_scale_), weights.squeeze(-1), ks, ke) @@ -197,10 +188,17 @@ def get_indices( score=logits, lengths=lengths, topk=self.index_topk, - row_starts=ke, + row_starts=ks, ) # 将 topk index 转化为 mem index + from ..triton_kernel.topk_index_to_mem_index import trans_topk_index_to_mem_index + + b_topk_index = trans_topk_index_to_mem_index( + topk_index=b_topk_index, + ragged_mem_index=att_state.ragged_mem_index, + ) + return b_topk_index @staticmethod From c8d3b42bba296d19c284a57df89a6b2e971a2096 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 07:50:46 +0000 Subject: [PATCH 47/58] fix --- .../models/deepseek3_2/layer_infer/transformer_layer_infer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 011cc341cd..14868c352a 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -104,8 +104,6 @@ def _token_attention_kernel( nsa_decode=True, nsa_decode_dict={ "topk_indices": 
infer_state.topk_indices, - "nsa_cache_seqlens": infer_state.nsa_cache_seqlens, - "nsa_cu_seqlens_k": infer_state.nsa_cu_seqlens_k, "softmax_scale": self.softmax_scale, "kv_lora_rank": self.kv_lora_rank, "qk_rope_head_dim": self.qk_rope_head_dim, From 1fc7a10a87e22aca928e101c41b2a81a98e67979 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 20:32:37 +0800 Subject: [PATCH 48/58] fix --- .../basemodel/attention/nsa/flashmla_sparse.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index b7ad12e1d7..42eb0d2937 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -15,13 +15,10 @@ class NsaFlashMlaSparseAttBackend(BaseAttBackend): def __init__(self, model): super().__init__(model=model) + device = get_current_device_id() self.ragged_mem_buffers = [ - torch.empty( - model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - torch.empty( - model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), + torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device) + for _ in range(2) ] def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparsePrefillAttState": @@ -101,7 +98,6 @@ def _nsa_prefill_att( @dataclasses.dataclass class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState): - cu_seqlens_q: torch.Tensor = None ks: torch.Tensor = None ke: torch.Tensor = None length: torch.Tensor = None @@ -110,8 +106,6 @@ class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState): nsa_cu_seqlens_k_new: torch.Tensor = None def init_state(self): - self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() - self.backend: NsaFlashMlaSparseAttBackend = self.backend model = 
self.backend.model use_cuda_graph = ( From ad6c8deb7456b65f5d1f27538fd2ef09dbf49e20 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 20:40:11 +0800 Subject: [PATCH 49/58] add comments --- lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py index 13a4e1deb3..af764898ec 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py +++ b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py @@ -132,6 +132,14 @@ def _gen_same_req_mark(b_req_idx, b_same_req_mark, BLOCK_SIZE: tl.constexpr): @torch.no_grad() def gen_same_req_mark(b_req_idx: torch.Tensor): + """ + b_req_idx: torch.Tensor + out: torch.Tensor + + demo: + b_req_idx = [1, 1, 2, 3, 3, 3] + out = [0, 2, 1, 0, 0, 3] + """ batch_size = b_req_idx.shape[0] b_same_req_mark = torch.empty((batch_size,), dtype=torch.int32, device=b_req_idx.device) b_same_req_mark.fill_(0) From 6e4691c82edd4ed09907a3536ac678d4e89178c8 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 20:56:21 +0800 Subject: [PATCH 50/58] fix --- .../basemodel/triton_kernel/gen_nsa_ks_ke.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py index af764898ec..27aa5f8e35 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py +++ b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py @@ -22,15 +22,16 @@ def _gen_nsa_ks_ke( BLOCK_SEQ_SPLIT: tl.constexpr, ): cur_index = tl.program_id(0) - # 不处于边界mark的最后一个req不进行处理。 + # 只处理最后一个同样req_idx的req进行处理,代表seq_len最长的那个。 + # req_mark 为 0,表示不是最后一个。 req_mark = tl.load(b_same_req_mark + cur_index) if req_mark == 0: return off = tl.arange(0, BLOCK_REQ) b_same_req_mark = tl.load(b_same_req_mark + off, off < cur_index, other=0) - b_seq_len_data = 
tl.load(b_seq_len + off, (off < cur_index) & (b_same_req_mark != 0), other=0) - pre_sum_seq_len = tl.sum(b_seq_len_data) + pre_b_seq_len_data = tl.load(b_seq_len + off, (off < cur_index) & (b_same_req_mark != 0), other=0) + pre_sum_seq_len = tl.sum(pre_b_seq_len_data) # 兼容 prefill 和 decode 的情况, decode 可能存在 mtp 的情况,各个请求会共享一个req对象,其处理比较特殊 q_seq_len = tl.load(b_q_seq_len + cur_index) + req_mark - 1 @@ -43,23 +44,25 @@ def _gen_nsa_ks_ke( for block_index in range(tl.cdiv(q_seq_len, BLOCK_SEQ_SPLIT)): block_start = block_index * BLOCK_SEQ_SPLIT block_end = min(q_seq_len, (block_index + 1) * BLOCK_SEQ_SPLIT) + block_off = block_start + tl.arange(0, BLOCK_SEQ_SPLIT) + mask = block_off < block_end ks_data = tl.zeros((BLOCK_SEQ_SPLIT,), dtype=tl.int32) ke_data = (cur_total_len - q_seq_len) + tl.arange(0, BLOCK_SEQ_SPLIT) tl.store( - ks + store_start_index + block_start + tl.arange(0, BLOCK_SEQ_SPLIT), + ks + store_start_index + block_off, ks_data + pre_sum_seq_len, - mask=block_start + tl.arange(0, BLOCK_SEQ_SPLIT) < block_end, + mask=mask, ) tl.store( - ke + store_start_index + block_start + tl.arange(0, BLOCK_SEQ_SPLIT), + ke + store_start_index + block_off, ke_data + pre_sum_seq_len, - mask=block_start + tl.arange(0, BLOCK_SEQ_SPLIT) < block_end, + mask=mask, ) tl.store( - lengths + store_start_index + block_start + tl.arange(0, BLOCK_SEQ_SPLIT), + lengths + store_start_index + block_off, ke_data - ks_data + 1, - mask=block_start + tl.arange(0, BLOCK_SEQ_SPLIT) < block_end, + mask=mask, ) for block_index in range(tl.cdiv(cur_total_len, BLOCK_SEQ_SPLIT)): From 1e4a4675cdc7e0b25f16ebfb1aad09a6640ff7dd Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 21:14:10 +0800 Subject: [PATCH 51/58] fix --- lightllm/models/deepseek3_2/model.py | 3 +-- lightllm/server/api_cli.py | 9 +++------ lightllm/server/tokenizer.py | 26 +++++++++++++------------- 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/lightllm/models/deepseek3_2/model.py 
b/lightllm/models/deepseek3_2/model.py index 0fa0ea6c8f..f3b5e12ad5 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -12,8 +12,7 @@ class DeepSeekV32Tokenizer: DeepSeek-V3.2's tokenizer_config.json does not ship with a Jinja chat template, so ``apply_chat_template`` would fail without either a manually - supplied ``--chat_template`` file or this wrapper. Activate it with - ``--tokenizer_mode deepseek_v32``. + supplied ``--chat_template`` file or this wrapper. """ def __init__(self, tokenizer): diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 121c272973..4b92298b6b 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -92,12 +92,9 @@ def make_argument_parser() -> argparse.ArgumentParser: "--tokenizer_mode", type=str, default="fast", - help="""tokenizer load mode, can be slow, fast, auto, or deepseek_v32. - slow mode load fast but run slow, good for debug and test. - fast mode get best performance. - auto mode will try to use fast mode, if failed will use slow mode. 
- deepseek_v32 mode wraps the tokenizer with Python-based DSML chat - template encoding for DeepSeek-V3.2 models (no --chat_template needed).""", + help="""tokenizer load mode, can be slow, fast or auto, slow mode load fast but run slow, + slow mode is good for debug and test, fast mode get best performance, auto mode will + try to use fast mode, if failed will use slow mode""", ) parser.add_argument( "--load_way", diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index bbbb052b85..2800bf0f6b 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -44,20 +44,7 @@ def get_tokenizer( *args, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) - model_type = model_cfg.get("model_type", "") """Gets a tokenizer for the given model name via Huggingface.""" - # DeepSeek-V3.2 custom tokenizer mode: wraps the HF tokenizer with - # a Python-based apply_chat_template that uses encoding_dsv32.py. - if model_type == "deepseek_v32": - from ..models.deepseek3_2.model import DeepSeekV32Tokenizer - - hf_tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs - ) - logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") - return DeepSeekV32Tokenizer(hf_tokenizer) - if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") @@ -89,6 +76,19 @@ def get_tokenizer( "slowdown. Consider using a fast tokenizer instead." ) + model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) + model_type = model_cfg.get("model_type", "") + # DeepSeek-V3.2 custom tokenizer mode: wraps the HF tokenizer with + # a Python-based apply_chat_template that uses encoding_dsv32.py. 
+ if model_type == "deepseek_v32": + from ..models.deepseek3_2.model import DeepSeekV32Tokenizer + + hf_tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs + ) + logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") + return DeepSeekV32Tokenizer(hf_tokenizer) + if model_cfg["architectures"][0] == "TarsierForConditionalGeneration": from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor From f8241e9a0a78c3b1a68b2a1111986b9bba7d953d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 21:17:49 +0800 Subject: [PATCH 52/58] fix --- lightllm/models/deepseek3_2/model.py | 30 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/lightllm/models/deepseek3_2/model.py b/lightllm/models/deepseek3_2/model.py index f3b5e12ad5..5831044311 100644 --- a/lightllm/models/deepseek3_2/model.py +++ b/lightllm/models/deepseek3_2/model.py @@ -6,6 +6,21 @@ from lightllm.common.basemodel.attention import get_nsa_prefill_att_backend_class, get_nsa_decode_att_backend_class +@ModelRegistry(["deepseek_v32"]) +class Deepseek3_2TpPartModel(Deepseek2TpPartModel): + + # weight class + transformer_weight_class = Deepseek3_2TransformerLayerWeight + + # infer class + transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer + + def _init_att_backend(self): + self.prefill_att_backend = get_nsa_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend = get_nsa_decode_att_backend_class(index=0)(model=self) + return + + class DeepSeekV32Tokenizer: """Tokenizer wrapper for DeepSeek-V3.2 that uses the Python-based encoding_dsv32 module instead of Jinja chat templates. 
@@ -93,18 +108,3 @@ def apply_chat_template( if tokenize: return self.tokenizer.encode(prompt, add_special_tokens=False) return prompt - - -@ModelRegistry(["deepseek_v32"]) -class Deepseek3_2TpPartModel(Deepseek2TpPartModel): - - # weight class - transformer_weight_class = Deepseek3_2TransformerLayerWeight - - # infer class - transformer_layer_infer_class = Deepseek3_2TransformerLayerInfer - - def _init_att_backend(self): - self.prefill_att_backend = get_nsa_prefill_att_backend_class(index=0)(model=self) - self.decode_att_backend = get_nsa_decode_att_backend_class(index=0)(model=self) - return From bf13746adf5ca699c39a2d6ad5cf5cce770ffa61 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 4 Mar 2026 21:31:02 +0800 Subject: [PATCH 53/58] fix --- .../deepseek3_2/layer_infer/transformer_layer_infer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 14868c352a..0a5023f849 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -33,8 +33,10 @@ def _get_qkv( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 ) q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) - - infer_state.topk_indices = self.indexer.get_indices(input, q, infer_state, layer_weight) + att_state = infer_state.prefill_att_state if infer_state.is_prefill else infer_state.decode_att_state + infer_state.topk_indices = self.indexer.get_indices( + hidden_states=input, q_lora=q, infer_state=infer_state, att_state=att_state, layer_weight=layer_weight + ) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) From 3d5eacd032b5d95f30c2ed644436abc425063c36 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 5 Mar 2026 02:49:21 +0000 Subject: [PATCH 
54/58] fix --- .../basemodel/triton_kernel/quantization/fp8act_quant_kernel.py | 2 +- lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py | 2 +- lightllm/server/api_cli.py | 2 +- test/acc/test_deepseekr1_mtp_ep.sh | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py index 5402a9caf1..0a68372887 100644 --- a/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py +++ b/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py @@ -154,7 +154,7 @@ def per_token_group_quant_fp8( ) # 使用SGL kernel进行量化 - sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max, False) + sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max, False, enable_v2=True) else: # 使用LightLLM kernel进行量化 x_s = alloc_func( diff --git a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py index 55b967e066..d0f8b45f81 100644 --- a/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py +++ b/lightllm/models/deepseek3_2/triton_kernel/extract_indexer_ks.py @@ -79,7 +79,7 @@ def extract_indexer_ks( # Allocate output tensors O_fp8 = torch.empty((out_token_num // (mtp_step + 1), head_dim), dtype=torch.float8_e4m3fn, device=I_buffer.device) - O_scale = torch.empty((out_token_num // (mtp_step + 1),), dtype=torch.float32, device=I_buffer.device) + O_scale = torch.empty((out_token_num // (mtp_step + 1), 1), dtype=torch.float32, device=I_buffer.device) assert b_seq_len.shape[0] % (mtp_step + 1) == 0 grid = (b_seq_len.shape[0] // (mtp_step + 1), min(256, max_kv_seq_len)) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 4b92298b6b..be24043538 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -259,7 
+259,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache") - parser.add_argument("--chunked_prefill_size", type=int, default=8192, help="chunked prefill size") + parser.add_argument("--chunked_prefill_size", type=int, default=None, help="chunked prefill size") parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh index 7ceb1658c8..a7c9df1af8 100644 --- a/test/acc/test_deepseekr1_mtp_ep.sh +++ b/test/acc/test_deepseekr1_mtp_ep.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --enable_ep_moe --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code \ No newline at end of file From 554b117962240bef82cec5eb0c125b90e26f09e1 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 5 Mar 2026 10:07:35 +0000 Subject: [PATCH 
55/58] fix --- docker/Dockerfile | 7 ++++--- docker/scripts/build.sh | 3 ++- .../common/basemodel/attention/nsa/flashmla_sparse.py | 2 +- requirements.txt | 9 ++++----- test/acc/test_deepseekr1_mtp_ep.sh | 2 +- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 8f73a603cc..cd2bcc6f95 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -3,7 +3,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 ARG PYTHON_VERSION=3.10 ARG MAMBA_VERSION=24.7.1-0 -ARG VLLM_VERSION=0.11.0 +ARG VLLM_VERSION=0.16.0 ARG TARGETPLATFORM ARG ENABLE_DEEPEP=1 ARG ENABLE_NIXL=1 @@ -20,14 +20,15 @@ RUN chmod 777 -R /tmp && \ curl \ g++ \ make \ - git && \ + git \ + wget && \ rm -rf /var/lib/apt/lists/* RUN case ${TARGETPLATFORM} in \ "linux/arm64") MAMBA_ARCH=aarch64 ;; \ *) MAMBA_ARCH=x86_64 ;; \ esac && \ - curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \ + wget -O ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \ bash ~/mambaforge.sh -b -p /opt/conda && \ rm ~/mambaforge.sh diff --git a/docker/scripts/build.sh b/docker/scripts/build.sh index 1699b39dd7..355d6c65b3 100644 --- a/docker/scripts/build.sh +++ b/docker/scripts/build.sh @@ -100,5 +100,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \ --build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \ --build-arg ENABLE_NIXL="${ENABLE_NIXL}" \ --build-arg ENABLE_CACHE="${ENABLE_CACHE}" \ - -t "${IMAGE_PREFIX}:${IMAGE_TAG}" . + --progress=plain \ + -t "${IMAGE_PREFIX}:${IMAGE_TAG}" . 
diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index 42eb0d2937..27967ded72 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -133,7 +133,7 @@ def init_state(self): ragged_mem_index=self.ragged_mem_index, ) self.nsa_cache_seqlens = torch.minimum( - torch.full(size=(self.infer_state.batch_size,), value=2048), self.infer_state.b_seq_len + torch.full(size=(self.infer_state.batch_size,), fill_value=2048, device="cuda"), self.infer_state.b_seq_len ) padded_seq_lens = torch.zeros(size=(self.nsa_cache_seqlens.shape[0] + 1,), dtype=torch.int32, device="cuda") # 进行 cumsum 操作 diff --git a/requirements.txt b/requirements.txt index 25cdab955d..5b0b201ae3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -59,7 +59,7 @@ six==1.16.0 sniffio==1.3.0 sortedcontainers==2.4.0 toolz==0.12.0 -torch==2.8.0 +torch==2.9.1 tqdm==4.65.0 transformers==4.57.1 tokenizers==0.22.1 @@ -80,16 +80,15 @@ frozendict==2.4.6 atomics==1.0.3 easydict==1.13 hypercorn==0.18.0 -flashinfer-python==0.2.4 -sgl-kernel==0.3.7.post1 +flashinfer-python==0.6.3 +sgl-kernel==0.3.21 httpx==0.28.1 librosa==0.11.0 cuda_bindings==12.9.0 orjson==3.11.2 setproctitle==1.3.6 -xformers==0.0.32.post1 xxhash==3.6.0 -torchvision==0.23.0 +torchvision==0.24.1 interegular==0.3.3 partial_json_parser==0.2.1.1.post6 websockets==15.0.1 diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh index a7c9df1af8..271176a571 100644 --- a/test/acc/test_deepseekr1_mtp_ep.sh +++ b/test/acc/test_deepseekr1_mtp_ep.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --enable_ep_moe --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir 
/mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --enable_ep_moe --model_dir /mtc/sufubao/DeepSeek-V3.2 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code \ No newline at end of file From cb0b6e267fed3c07730e83c4bdaf5d58aac034a4 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 6 Mar 2026 06:34:41 +0000 Subject: [PATCH 56/58] fix --- docker/Dockerfile | 1 - lightllm/common/basemodel/attention/nsa/flashmla_sparse.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index cd2bcc6f95..e766107ae7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -45,7 +45,6 @@ COPY ./requirements.txt /lightllm/requirements.txt RUN pip install -U pip RUN pip install -r /lightllm/requirements.txt --no-cache-dir RUN pip install --no-cache-dir vllm==${VLLM_VERSION} -RUN pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/lightllm_kernel-0.1.0-cp310-cp310-linux_x86_64.whl RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index 27967ded72..1e549d59f8 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -133,7 +133,8 @@ def init_state(self): ragged_mem_index=self.ragged_mem_index, ) self.nsa_cache_seqlens = torch.minimum( - torch.full(size=(self.infer_state.batch_size,), 
fill_value=2048, device="cuda"), self.infer_state.b_seq_len + torch.full(size=(self.infer_state.batch_size,), fill_value=2048, dtype=torch.int32, device="cuda"), + self.infer_state.b_seq_len, ) padded_seq_lens = torch.zeros(size=(self.nsa_cache_seqlens.shape[0] + 1,), dtype=torch.int32, device="cuda") # 进行 cumsum 操作 From dd6996ee58ed78965194fd55cfe5ffd5ec994655 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 6 Mar 2026 07:32:05 +0000 Subject: [PATCH 57/58] fix --- .../attention/nsa/flashmla_sparse.py | 2 ++ .../basemodel/triton_kernel/gen_nsa_ks_ke.py | 22 +++++++++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index 1e549d59f8..1cc2665c60 100644 --- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py +++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py @@ -53,6 +53,7 @@ def init_state(self): req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num, ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, ) return @@ -131,6 +132,7 @@ def init_state(self): req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, q_token_num=self.infer_state.b_seq_len.shape[0], ragged_mem_index=self.ragged_mem_index, + hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, ) self.nsa_cache_seqlens = torch.minimum( torch.full(size=(self.infer_state.batch_size,), fill_value=2048, dtype=torch.int32, device="cuda"), diff --git a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py index 27aa5f8e35..4edad6177a 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py +++ b/lightllm/common/basemodel/triton_kernel/gen_nsa_ks_ke.py @@ -91,12 +91,16 @@ def gen_nsa_ks_ke( req_to_token_index: 
torch.Tensor, q_token_num: int, ragged_mem_index: torch.Tensor, + hold_req_idx: int = -1, ): + """ + hold_req_idx 这是一个特殊req idx,主要是用于padding 请求数量时使用,所以其处理存在特殊性。 + """ batch_size = b_seq_len.shape[0] ks = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) ke = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) lengths = torch.empty((q_token_num,), dtype=torch.int32, device=b_seq_len.device) - b_same_req_mark = gen_same_req_mark(b_req_idx) + b_same_req_mark = gen_same_req_mark(b_req_idx, hold_req_idx=hold_req_idx) _gen_nsa_ks_ke[(batch_size,)]( b_seq_len=b_seq_len, @@ -117,9 +121,14 @@ def gen_nsa_ks_ke( @triton.jit -def _gen_same_req_mark(b_req_idx, b_same_req_mark, BLOCK_SIZE: tl.constexpr): +def _gen_same_req_mark(b_req_idx, b_same_req_mark, hold_req_idx, BLOCK_SIZE: tl.constexpr): cur_index = tl.program_id(0) cur_req_idx = tl.load(b_req_idx + cur_index) + # hold req idx 可能重复,但是单独成组。 + if cur_req_idx == hold_req_idx: + tl.store(b_same_req_mark + cur_index, 1) + return + off = tl.arange(0, BLOCK_SIZE) pre_req_idxs = tl.load(b_req_idx + off, (off < cur_index) & (off < tl.num_programs(0)), other=-1) after_req_idxs = tl.load(b_req_idx + off, (off > cur_index) & (off < tl.num_programs(0)), other=-1) @@ -134,9 +143,11 @@ def _gen_same_req_mark(b_req_idx, b_same_req_mark, BLOCK_SIZE: tl.constexpr): @torch.no_grad() -def gen_same_req_mark(b_req_idx: torch.Tensor): +def gen_same_req_mark(b_req_idx: torch.Tensor, hold_req_idx: int = -1): """ b_req_idx: torch.Tensor + hold_req_idx: int default is -1, hold_req_idx is used to pad batch size to cuda graph batch size, + so need special handle. 
out: torch.Tensor demo: @@ -147,6 +158,9 @@ def gen_same_req_mark(b_req_idx: torch.Tensor): b_same_req_mark = torch.empty((batch_size,), dtype=torch.int32, device=b_req_idx.device) b_same_req_mark.fill_(0) _gen_same_req_mark[(batch_size,)]( - b_req_idx=b_req_idx, b_same_req_mark=b_same_req_mark, BLOCK_SIZE=triton.next_power_of_2(batch_size) + b_req_idx=b_req_idx, + b_same_req_mark=b_same_req_mark, + hold_req_idx=hold_req_idx, + BLOCK_SIZE=triton.next_power_of_2(batch_size), ) return b_same_req_mark From 0a28a595b1a9af4592e550a0ec82a0d7c5182c6f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 6 Mar 2026 09:42:34 +0000 Subject: [PATCH 58/58] pad deepseekv3.2 headdim --- .../deepseek3_2mem_manager.py | 49 +++++++++++------- .../layer_infer/transformer_layer_infer.py | 50 +++++++++++++------ 2 files changed, 66 insertions(+), 33 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py index 679a07bde4..fbf9f88c84 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py @@ -1,25 +1,36 @@ import torch - +from typing import Any from lightllm.common.kv_cache_mem_manager.deepseek2_mem_manager import Deepseek2MemoryManager -class IndexerKSBuffer: - def __init__(self, size: int, head_num: int, head_dim: int, layer_num: int, dtype=torch.uint8): - self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") - - class Deepseek3_2MemoryManager(Deepseek2MemoryManager): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) - self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, layer_num) - - def get_cell_size(self): - return super().get_cell_size() + 132 - - def _free_buffers(self): - super()._free_buffers() - 
self.indexer_ks_buffer = None - - def resize_mem(self, new_size): - super().resize_mem(new_size) - self.indexer_ks_buffer = IndexerKSBuffer(self.size, 1, 132, self.layer_num) + assert dtype in [torch.bfloat16, torch.float16] + # 因为V3.2 使用了NSA 稀疏的缘故,所以其head_dim 会比原始的kv 多 128 + 4 = 132 个字节 (128 fp8 + 4byte float32 scale), + # 但是为了让整个数组具备16字节对齐,满足一些算子的约束,修改为添加 128 + 16 = 144 个字节, 这 144个字节中,后面132个字节用于 + # 存储真实数据,剩下12个,浪费了,只是占位。 + # 所以在子类中定制为其pad上,对外使用的接口,需要进行重载区别。 + super().__init__(size, dtype, head_num, head_dim + (144 // 2), layer_num, always_copy, mem_fraction) + + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from ..basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv + + rope_dim = 64 + kv_lora_rank = kv.shape[2] - rope_dim + assert kv_lora_rank + rope_dim == self.kv_buffer.shape[-1] - (144 // 2) + + destindex_copy_kv( + kv[:, :, :kv_lora_rank], + kv[:, :, kv_lora_rank:], + mem_index, + self.kv_buffer[layer_index][:, :, :kv_lora_rank], + self.kv_buffer[layer_index][:, :, kv_lora_rank : (kv_lora_rank + rope_dim)], + ) + return + + def get_att_input_params(self, layer_index: int) -> Any: + kv = self.kv_buffer[layer_index][:, :, : (self.head_dim - (144 // 2))] + return kv diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 0a5023f849..b3c0270b2b 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -33,10 +33,11 @@ def _get_qkv( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 ) q = rmsnorm_forward(q, weight=layer_weight.q_a_layernorm_.weight, eps=self.eps_) - att_state = infer_state.prefill_att_state if infer_state.is_prefill else infer_state.decode_att_state - infer_state.topk_indices = self.indexer.get_indices( - 
hidden_states=input, q_lora=q, infer_state=infer_state, att_state=att_state, layer_weight=layer_weight - ) + + infer_state.get_topk_indices_params = { + "hidden_states": input, + "q_lora": q, + } q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) @@ -70,21 +71,30 @@ def _context_attention_kernel( q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) q_all = torch.cat([q_nope, q_rope], dim=-1) + # 计算 topk_indices + att_state = infer_state.prefill_att_state + topk_indices = self.indexer.get_indices( + hidden_states=infer_state.get_topk_indices_params["hidden_states"], + q_lora=infer_state.get_topk_indices_params["q_lora"], + infer_state=infer_state, + att_state=att_state, + layer_weight=layer_weight, + ) + del infer_state.get_topk_indices_params + # Use NSA backend for attention computation att_control = AttControl( nsa_prefill=True, nsa_prefill_dict={ - "topk_indices": infer_state.topk_indices, + "topk_indices": topk_indices, "softmax_scale": self.softmax_scale, "kv_lora_rank": self.kv_lora_rank, }, ) - del infer_state.topk_indices - mla_out = infer_state.prefill_att_state.prefill_att( q=q_all, - k=infer_state.mem_manager.kv_buffer[self.layer_num_], + k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_), v=None, att_control=att_control, ) @@ -101,22 +111,31 @@ def _token_attention_kernel( q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + # 计算 topk_indices + att_state = infer_state.decode_att_state + topk_indices = self.indexer.get_indices( + hidden_states=infer_state.get_topk_indices_params["hidden_states"], + q_lora=infer_state.get_topk_indices_params["q_lora"], + infer_state=infer_state, + att_state=att_state, + layer_weight=layer_weight, + ) + del infer_state.get_topk_indices_params + # Use NSA backend for attention computation att_control = 
AttControl( nsa_decode=True, nsa_decode_dict={ - "topk_indices": infer_state.topk_indices, + "topk_indices": topk_indices, "softmax_scale": self.softmax_scale, "kv_lora_rank": self.kv_lora_rank, "qk_rope_head_dim": self.qk_rope_head_dim, }, ) - del infer_state.topk_indices - o_tensor = infer_state.decode_att_state.decode_att( q=(q_nope, q_rope), - k=infer_state.mem_manager.kv_buffer[self.layer_num_], + k=infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_), v=None, att_control=att_control, ) @@ -153,7 +172,10 @@ def get_indices( k_fp8, k_scale = act_quant(k, self.block_size, self.scale_fmt) destindex_copy_indexer_ks( - k_fp8, k_scale, infer_state.mem_index, infer_state.mem_manager.indexer_ks_buffer.kv_buffer[self.layer_idx_] + K_fp8=k_fp8, + K_scale=k_scale, + DestLoc=infer_state.mem_index, + O_buffer=infer_state.mem_manager.kv_buffer[self.layer_idx_].view(dtype=torch.uint8)[:, :, -132:], ) weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale @@ -169,7 +191,7 @@ def get_indices( mtp_step = get_env_start_args().mtp_step # Use efficient Triton kernel to extract FP8 keys and scales from buffer k_fp8_, k_scale_ = extract_indexer_ks( - I_buffer=infer_state.mem_manager.indexer_ks_buffer.kv_buffer[self.layer_idx_], + I_buffer=infer_state.mem_manager.kv_buffer[self.layer_idx_].view(dtype=torch.uint8)[:, :, -132:], b_seq_len=infer_state.b_seq_len, b_req_idx=infer_state.b_req_idx, req_to_token_indexs=infer_state.req_manager.req_to_token_indexs,