Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
c24e04a
support deepseek v3.2
sufubao Nov 5, 2025
8926ac9
fix
sufubao Nov 5, 2025
d1956cc
fix
sufubao Nov 6, 2025
4f8a747
fix
sufubao Nov 7, 2025
2cf0833
need fix
sufubao Nov 7, 2025
bb8f087
run like deepseek v3
sufubao Nov 10, 2025
3f3a656
fix
sufubao Nov 10, 2025
19c9128
fix
sufubao Nov 10, 2025
303ca19
fix
sufubao Nov 10, 2025
3b3748a
fix
sufubao Nov 10, 2025
5d8119e
can run without cudagraph
sufubao Nov 10, 2025
bcafce3
fix cudagraph
sufubao Nov 10, 2025
66138e5
can run
sufubao Dec 26, 2025
17204e7
abstract NSA attention into backend framework
sufubao Feb 2, 2026
b52f3db
rebase
Mar 4, 2026
e51d4ce
fix
sufubao Feb 4, 2026
7d8be57
fix
sufubao Feb 4, 2026
da09156
fix
sufubao Feb 4, 2026
512e32c
rebase
Mar 4, 2026
2ca681c
fix
Feb 4, 2026
f412e42
deepseekv32 model_type condition
shihaobai Feb 4, 2026
d72f085
fix v1 streaming
Feb 5, 2026
dd1a067
exclude_none
Feb 5, 2026
f352854
fix deepseekv3.2
Mar 3, 2026
579a47e
fix
Mar 3, 2026
8b6b5b7
fix
Mar 3, 2026
c0157d3
fix
Mar 3, 2026
f92716a
fix
Mar 3, 2026
879e907
fix
hiworldwzj Mar 3, 2026
68efbdc
fix
hiworldwzj Mar 3, 2026
7d8e54d
fix
hiworldwzj Mar 3, 2026
932b402
Fix
hiworldwzj Mar 3, 2026
f88503e
fix
Mar 4, 2026
54f8d5c
fix
Mar 4, 2026
bc618f2
fix
Mar 4, 2026
c196bca
fix
Mar 4, 2026
3de599b
fix
Mar 4, 2026
2fb0728
fix
Mar 4, 2026
51d7e37
fix
Mar 4, 2026
2f6b3a5
fix
Mar 4, 2026
a37dc86
fix
Mar 4, 2026
adb7cad
fix
Mar 4, 2026
ab0ce7a
fix
Mar 4, 2026
e069e9a
fix
Mar 4, 2026
abc836d
fix
Mar 4, 2026
63eb4e1
fix
Mar 4, 2026
c8d3b42
fix
Mar 4, 2026
1fc7a10
fix
hiworldwzj Mar 4, 2026
ad6c8de
add comments
hiworldwzj Mar 4, 2026
6e4691c
fix
hiworldwzj Mar 4, 2026
1e4a467
fix
hiworldwzj Mar 4, 2026
f8241e9
fix
hiworldwzj Mar 4, 2026
bf13746
fix
hiworldwzj Mar 4, 2026
3d5eacd
fix
Mar 5, 2026
554b117
fix
hiworldwzj Mar 5, 2026
cb0b6e2
fix
hiworldwzj Mar 6, 2026
dd6996e
fix
hiworldwzj Mar 6, 2026
0a28a59
pad deepseekv3.2 headdim
hiworldwzj Mar 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04

ARG PYTHON_VERSION=3.10
ARG MAMBA_VERSION=24.7.1-0
ARG VLLM_VERSION=0.11.0
ARG VLLM_VERSION=0.16.0
ARG TARGETPLATFORM
ARG ENABLE_DEEPEP=1
ARG ENABLE_NIXL=1
Expand All @@ -20,14 +20,15 @@ RUN chmod 777 -R /tmp && \
curl \
g++ \
make \
git && \
git \
wget && \
rm -rf /var/lib/apt/lists/*

RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \
wget -O ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh

Expand All @@ -44,7 +45,6 @@ COPY ./requirements.txt /lightllm/requirements.txt
RUN pip install -U pip
RUN pip install -r /lightllm/requirements.txt --no-cache-dir
RUN pip install --no-cache-dir vllm==${VLLM_VERSION}
RUN pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/lightllm_kernel-0.1.0-cp310-cp310-linux_x86_64.whl

RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*

Expand Down
3 changes: 2 additions & 1 deletion docker/scripts/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,6 @@ DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile \
--build-arg ENABLE_DEEPEP="${ENABLE_DEEPEP}" \
--build-arg ENABLE_NIXL="${ENABLE_NIXL}" \
--build-arg ENABLE_CACHE="${ENABLE_CACHE}" \
-t "${IMAGE_PREFIX}:${IMAGE_TAG}" .
--progress=plain \
-t "${IMAGE_PREFIX}:${IMAGE_TAG}" .

5 changes: 5 additions & 0 deletions lightllm/common/basemodel/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend

# NSA backend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend

from .create_utils import (
get_prefill_att_backend_class,
get_decode_att_backend_class,
get_mla_prefill_att_backend_class,
get_mla_decode_att_backend_class,
get_nsa_prefill_att_backend_class,
get_nsa_decode_att_backend_class,
)
5 changes: 5 additions & 0 deletions lightllm/common/basemodel/attention/base_att.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class AttControl:
mla_prefill_dict: Dict = None
mla_decode: bool = False
mla_decode_dict: Dict = None
# NSA (native sparse attention) specific control parameters
nsa_prefill: bool = False
nsa_prefill_dict: Dict = None
nsa_decode: bool = False
nsa_decode_dict: Dict = None


@dataclass
Expand Down
46 changes: 37 additions & 9 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Attention backend selection utilities."""

import os
import torch
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.log_utils import init_logger
from lightllm.utils.backend_validator import validate
from typing import Dict
from .base_att import BaseAttBackend
from .triton.fp import TritonAttBackend
from .triton.int4kv import Int4kvTritonAttBackend
Expand All @@ -16,6 +14,7 @@
from .flashinfer.fp8 import Fp8FlashInferAttBackend
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend

logger = init_logger(__name__)

Expand Down Expand Up @@ -46,16 +45,25 @@
},
}

nsa_data_type_to_backend = {
"None": {
"flashmla_sparse": NsaFlashMlaSparseAttBackend,
# Future backends: "fa3", "tilelang", "aiter"
},
}


def _auto_select_backend(
llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", "triton"]
llm_dtype: str,
kv_type_to_backend: Dict[str, Dict[str, BaseAttBackend]],
priority_list: list = ["fa3", "flashinfer", "triton"],
) -> type:
"""Auto-select the best available backend with validation.

Priority: FA3 > FlashInfer > Triton
Each backend is validated in a subprocess with ground truth checks.
"""
backend_map = mla_data_type_to_backend if is_mla else data_type_to_backend
backend_map = kv_type_to_backend

for backend_name in priority_list:
if validate(backend_name):
Expand All @@ -74,7 +82,7 @@ def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashi
if backend_str != "auto":
return data_type_to_backend[llm_dtype][backend_str]
else:
return _auto_select_backend(llm_dtype, is_mla=False, priority_list=priority_list)
return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list)


def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend:
Expand All @@ -84,7 +92,7 @@ def get_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashin
if backend_str != "auto":
return data_type_to_backend[llm_dtype][backend_str]
else:
return _auto_select_backend(llm_dtype, is_mla=False, priority_list=priority_list)
return _auto_select_backend(llm_dtype, kv_type_to_backend=data_type_to_backend, priority_list=priority_list)


def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend:
Expand All @@ -94,7 +102,7 @@ def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "fl
if backend_str != "auto":
return mla_data_type_to_backend[llm_dtype][backend_str]
else:
return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list)
return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list)


def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend:
Expand All @@ -104,4 +112,24 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla
if backend_str != "auto":
return mla_data_type_to_backend[llm_dtype][backend_str]
else:
return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list)
return _auto_select_backend(llm_dtype, kv_type_to_backend=mla_data_type_to_backend, priority_list=priority_list)


def get_nsa_prefill_att_backend_class(index=0, priority_list: list = ["flashmla_sparse"]) -> BaseAttBackend:
    """Resolve the NSA prefill attention backend class for backend slot *index*.

    An explicit backend name in the start args wins; "auto" falls back to
    validated auto-selection over *priority_list*.
    """
    start_args = get_env_start_args()
    kv_dtype = start_args.llm_kv_type
    requested = start_args.llm_prefill_att_backend[index]
    if requested == "auto":
        return _auto_select_backend(
            kv_dtype, kv_type_to_backend=nsa_data_type_to_backend, priority_list=priority_list
        )
    return nsa_data_type_to_backend[kv_dtype][requested]


def get_nsa_decode_att_backend_class(index=0, priority_list: list = ["flashmla_sparse"]) -> BaseAttBackend:
    """Resolve the NSA decode attention backend class for backend slot *index*.

    An explicit backend name in the start args wins; "auto" falls back to
    validated auto-selection over *priority_list*.
    """
    start_args = get_env_start_args()
    kv_dtype = start_args.llm_kv_type
    requested = start_args.llm_decode_att_backend[index]
    if requested == "auto":
        return _auto_select_backend(
            kv_dtype, kv_type_to_backend=nsa_data_type_to_backend, priority_list=priority_list
        )
    return nsa_data_type_to_backend[kv_dtype][requested]
13 changes: 13 additions & 0 deletions lightllm/common/basemodel/attention/nsa/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""NSA (Native Sparse Attention) backend implementations."""

from .flashmla_sparse import (
NsaFlashMlaSparseAttBackend,
NsaFlashMlaSparsePrefillAttState,
NsaFlashMlaSparseDecodeAttState,
)

__all__ = [
"NsaFlashMlaSparseAttBackend",
"NsaFlashMlaSparsePrefillAttState",
"NsaFlashMlaSparseDecodeAttState",
]
192 changes: 192 additions & 0 deletions lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/nsa_backend.py
# Uses sgl_kernel.flash_mla and sgl_kernel.flash_attn from the sglang kernel library.

import dataclasses
import torch
from typing import Tuple, TYPE_CHECKING

from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from lightllm.utils.dist_utils import get_current_device_id

if TYPE_CHECKING:
from lightllm.common.basemodel.infer_struct import InferStateInfo


class NsaFlashMlaSparseAttBackend(BaseAttBackend):
    """NSA (native sparse attention) backend backed by FlashMLA sparse kernels.

    Pre-allocates two int32 ragged-memory index buffers (indexed by
    microbatch in the decode state) sized for the CUDA-graph worst case
    (graph_max_batch_size * max_seq_length), so decode under CUDA graphs
    can reuse fixed storage instead of allocating per step.
    """

    def __init__(self, model):
        super().__init__(model=model)
        device_id = get_current_device_id()
        buffer_len = model.graph_max_batch_size * model.max_seq_length
        buffers = []
        for _ in range(2):
            buffers.append(torch.empty(buffer_len, dtype=torch.int32, device=device_id))
        self.ragged_mem_buffers = buffers

    def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparsePrefillAttState":
        """Build a per-batch prefill attention state bound to this backend."""
        return NsaFlashMlaSparsePrefillAttState(backend=self, infer_state=infer_state)

    def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaSparseDecodeAttState":
        """Build a per-batch decode attention state bound to this backend."""
        return NsaFlashMlaSparseDecodeAttState(backend=self, infer_state=infer_state)


@dataclasses.dataclass
class NsaFlashMlaSparsePrefillAttState(BasePrefillAttState):
    """Prefill attention state for NSA using flash_mla_sparse_fwd."""

    # Per-request key-range starts produced by gen_nsa_ks_ke.
    ks: torch.Tensor = None
    # Per-request key-range ends produced by gen_nsa_ks_ke.
    ke: torch.Tensor = None
    # Per-token lengths produced by gen_nsa_ks_ke.
    lengths: torch.Tensor = None
    # int32 mapping from ragged token positions to memory slots; filled by gen_nsa_ks_ke.
    ragged_mem_index: torch.Tensor = None

    def init_state(self):
        # Re-annotate for static analysis: narrows self.backend to the NSA backend type.
        self.backend: NsaFlashMlaSparseAttBackend = self.backend
        # Prefill batches can exceed the pre-allocated graph buffers, so always
        # allocate a fresh index buffer sized to this batch's total token count.
        self.ragged_mem_index = torch.empty(
            self.infer_state.total_token_num,
            dtype=torch.int32,
            device=get_current_device_id(),
        )
        # Local import keeps the triton kernel off the module import path.
        from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke

        # q_token_num excludes prefix-cache tokens: only the newly fed tokens
        # have query rows in this prefill step.
        self.ks, self.ke, self.lengths = gen_nsa_ks_ke(
            b_seq_len=self.infer_state.b_seq_len,
            b_q_seq_len=self.infer_state.b_q_seq_len,
            b_req_idx=self.infer_state.b_req_idx,
            req_to_token_index=self.infer_state.req_manager.req_to_token_indexs,
            q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num,
            ragged_mem_index=self.ragged_mem_index,
            hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID,
        )
        return

    def prefill_att(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_control: AttControl = AttControl(),
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        """Run NSA sparse prefill attention.

        Note: `v` and `alloc_func` are accepted for interface compatibility but
        unused — the MLA-style kernel reads values from the combined `k` buffer.
        The default AttControl() would always fail the asserts below, so callers
        must pass an nsa-configured control.
        """
        assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
        assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"

        return self._nsa_prefill_att(q=q, kv=k, att_control=att_control)

    def _nsa_prefill_att(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
        att_control: AttControl,
    ) -> torch.Tensor:
        """Invoke the FlashMLA sparse forward kernel over top-k selected KV indices."""
        # Local import: sgl_kernel is an optional GPU dependency.
        from sgl_kernel.flash_mla import flash_mla_sparse_fwd

        nsa_dict = att_control.nsa_prefill_dict
        topk_indices = nsa_dict["topk_indices"]
        softmax_scale = nsa_dict["softmax_scale"]
        kv_lora_rank = nsa_dict["kv_lora_rank"]

        # Kernel expects a 3-D index tensor; insert the middle axis if absent
        # (presumably a heads/groups axis — TODO confirm against kernel docs).
        if topk_indices.ndim == 2:
            topk_indices = topk_indices.unsqueeze(1)

        # d_v=kv_lora_rank: only the first kv_lora_rank dims of kv are values;
        # the trailing rope dims are key-only. Extra outputs (lse etc.) unused.
        mla_out, _, _ = flash_mla_sparse_fwd(
            q=q,
            kv=kv,
            indices=topk_indices,
            sm_scale=softmax_scale,
            d_v=kv_lora_rank,
        )
        return mla_out


@dataclasses.dataclass
class NsaFlashMlaSparseDecodeAttState(BaseDecodeAttState):
    """Decode attention state for NSA using flash_attn_with_kvcache over top-k sparse indices."""

    # Per-request key-range starts produced by gen_nsa_ks_ke.
    ks: torch.Tensor = None
    # Per-request key-range ends produced by gen_nsa_ks_ke.
    ke: torch.Tensor = None
    # Per-token lengths produced by gen_nsa_ks_ke.
    # Fixed: was declared `length` while init_state assigned `self.lengths`,
    # leaving the declared field dead (always None) and inconsistent with
    # NsaFlashMlaSparsePrefillAttState.lengths.
    lengths: torch.Tensor = None
    # int32 mapping from ragged token positions to memory slots; filled by gen_nsa_ks_ke.
    ragged_mem_index: torch.Tensor = None
    # Per-request KV lengths clamped to the sparse top-k window (<= 2048).
    nsa_cache_seqlens: torch.Tensor = None
    # Exclusive prefix sum over nsa_cache_seqlens (cu_seqlens layout, leading 0).
    nsa_cu_seqlens_k_new: torch.Tensor = None

    def init_state(self):
        # Re-annotate for static analysis: narrows self.backend to the NSA backend type.
        self.backend: NsaFlashMlaSparseAttBackend = self.backend
        model = self.backend.model
        # Reuse the pre-allocated per-microbatch buffer when this batch fits
        # within the CUDA-graph capture limits; otherwise allocate fresh.
        use_cuda_graph = (
            self.infer_state.batch_size <= model.graph_max_batch_size
            and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch
        )

        if use_cuda_graph:
            self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index]
        else:
            self.ragged_mem_index = torch.empty(
                self.infer_state.total_token_num,
                dtype=torch.int32,
                device=get_current_device_id(),
            )

        # Local import keeps the triton kernel off the module import path.
        from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke

        # Decode has exactly one query token per request, hence
        # q_token_num == batch size (b_seq_len.shape[0]).
        self.ks, self.ke, self.lengths = gen_nsa_ks_ke(
            b_seq_len=self.infer_state.b_seq_len,
            b_q_seq_len=self.infer_state.b_q_seq_len,
            b_req_idx=self.infer_state.b_req_idx,
            req_to_token_index=self.infer_state.req_manager.req_to_token_indexs,
            q_token_num=self.infer_state.b_seq_len.shape[0],
            ragged_mem_index=self.ragged_mem_index,
            hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID,
        )
        # Clamp each request's visible KV length to 2048 — presumably the NSA
        # top-k selection size (width of topk_indices); TODO confirm against
        # the model's index_topk configuration.
        self.nsa_cache_seqlens = torch.minimum(
            torch.full(size=(self.infer_state.batch_size,), fill_value=2048, dtype=torch.int32, device="cuda"),
            self.infer_state.b_seq_len,
        )
        # Build the cu_seqlens (exclusive prefix sum): slot 0 stays 0, then
        # cumulative sums of the clamped lengths.
        padded_seq_lens = torch.zeros(size=(self.nsa_cache_seqlens.shape[0] + 1,), dtype=torch.int32, device="cuda")
        padded_seq_lens[1:].copy_(self.nsa_cache_seqlens, non_blocking=True)
        self.nsa_cu_seqlens_k_new = padded_seq_lens.cumsum(dim=0, dtype=torch.int32)

    def decode_att(
        self,
        q: Tuple[torch.Tensor, torch.Tensor],
        k: torch.Tensor,
        v: torch.Tensor,
        att_control: AttControl = AttControl(),
        alloc_func=torch.empty,
    ) -> torch.Tensor:
        """Run NSA sparse decode attention.

        `q` is a (q_nope, q_rope) pair. `v` and `alloc_func` are accepted for
        interface compatibility but unused — values come from the combined `k`
        buffer. The default AttControl() would always fail the asserts below,
        so callers must pass an nsa-configured control.
        """
        assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
        assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"

        return self._nsa_decode_att(q=q, kv=k, att_control=att_control)

    def _nsa_decode_att(
        self,
        q: Tuple[torch.Tensor, torch.Tensor],
        kv: torch.Tensor,
        att_control: AttControl,
    ) -> torch.Tensor:
        """Invoke flash_attn_with_kvcache with topk_indices as a per-token page table."""
        # Local import: sgl_kernel is an optional GPU dependency.
        from sgl_kernel.flash_attn import flash_attn_with_kvcache

        nsa_dict = att_control.nsa_decode_dict
        topk_indices = nsa_dict["topk_indices"]
        softmax_scale = nsa_dict["softmax_scale"]
        kv_lora_rank = nsa_dict["kv_lora_rank"]
        qk_rope_head_dim = nsa_dict["qk_rope_head_dim"]

        q_nope, q_rope = q

        # Split the fused KV buffer: last qk_rope_head_dim dims are the rope
        # key part, the leading kv_lora_rank dims are the compressed KV (which
        # also serves as the value cache in MLA). Views use page size 1.
        k_rope = kv[:, :, -qk_rope_head_dim:].view(-1, 1, 1, qk_rope_head_dim)
        kv_nope = kv[:, :, :-qk_rope_head_dim].view(-1, 1, 1, kv_lora_rank)

        o_tensor = flash_attn_with_kvcache(
            q=q_rope,
            k_cache=k_rope,
            v_cache=kv_nope,
            qv=q_nope,
            page_table=topk_indices,
            cache_seqlens=self.nsa_cache_seqlens,
            cu_seqlens_q=self.infer_state.b1_cu_q_seq_len,
            cu_seqlens_k_new=self.nsa_cu_seqlens_k_new,
            max_seqlen_q=self.infer_state.max_q_seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )
        return o_tensor
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _init_parallel_params(self):
self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_
if self.enable_ep_moe:
assert self.num_fused_shared_experts == 0, "num_fused_shared_experts must be 0 when enable_ep_moe"
logger.info(
logger.debug(
f"global_rank {self.global_rank_} layerindex {self.layer_num_} "
f"redundancy_expertids: {self.redundancy_expert_ids}"
)
Expand Down
Loading
Loading