"The best optimization is the one you don't have to think about."
ejKernel is a production-grade kernel library for JAX that provides highly optimized implementations of deep learning operations with automatic multi-backend support. The library features a sophisticated configuration management system with autotuning, comprehensive type safety, and seamless execution across GPUs, TPUs, and CPUs.
Note
ejKernel contains no AI-generated code. All kernels, modules, and core logic are designed and implemented by human developers. AI tooling (Opus 4.5) is used exclusively for documentation, which may therefore contain minor inaccuracies. There is no "vibe coding" or automated code generation anywhere in the codebase.
- Key Features
- Installation
- Quick Start
- Architecture Overview
- Supported Operations
- Advanced Usage
- Development
- Testing
- Contributing
- Citation
- License
- 7-Tier Configuration System: Override → Overlay → Memory Cache → Persistent Cache → Autotune → Heuristics → Error
- Automatic Platform Detection: Seamlessly selects optimal implementation based on hardware
- Priority-Based Registry: Multi-backend support with intelligent fallback mechanisms
- Device Fingerprinting: Hardware-specific configuration caching for optimal performance
- 30+ Deep Learning Operations: Flash Attention v2, Flash MLA, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, Quantized MatMul, State Space Models (Mamba), RWKV (v4/v6/v7), and more
- Memory Efficiency: Custom VJP implementations with O(N) memory complexity for attention
- Distributed Support: Full shard_map integration for model and data parallelism
- Mixed Precision: Comprehensive dtype support with automatic gradient conversion
- Type Safety: Full jaxtyping annotations with runtime validation via beartype
- Comprehensive Testing: Cross-backend validation, performance benchmarks, integration tests
- Atomic Persistence: Thread-safe configuration storage with automatic optimization
- Profiling Integration: Built-in support for JAX profiling and performance monitoring
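As a rough illustration, the 7-tier configuration fallback listed above can be sketched as a chain that returns the first tier with an answer. The names below are hypothetical and are not the library's actual API:

```python
def resolve_config(op_id, override=None, overlay=None,
                   mem_cache=None, disk_cache=None,
                   autotune=None, heuristic=None):
    """Illustrative 7-tier fallback; hypothetical names, not the library API."""
    mem_cache = mem_cache or {}
    disk_cache = disk_cache or {}
    tiers = (
        lambda: override,                                 # 1. Override
        lambda: overlay,                                  # 2. Overlay
        lambda: mem_cache.get(op_id),                     # 3. Memory Cache
        lambda: disk_cache.get(op_id),                    # 4. Persistent Cache
        lambda: autotune(op_id) if autotune else None,    # 5. Autotune
        lambda: heuristic(op_id) if heuristic else None,  # 6. Heuristics
    )
    for tier in tiers:
        cfg = tier()
        if cfg is not None:
            return cfg
    raise LookupError(f"no configuration found for {op_id!r}")  # 7. Error
```

Each tier is only consulted when every tier above it comes back empty, which is why an explicit override always wins and autotuning only runs when no cached result exists.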
```bash
pip install ejkernel

# GPU Support (CUDA)
pip install ejkernel[cuda]

# TPU Support
pip install ejkernel[tpu]

# Development Installation
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel
pip install -e ".[dev]"
```

- Python 3.11-3.13
- JAX >= 0.9.0
- Triton == 3.6.0 (for GPU)
- nvidia-cutlass-dsl >= 4.4.0 (optional, for CuTe DSL kernels)
- jax-tvm-ffi == 0.1.2 (optional, for CuTe TVM-FFI primitive path)
- jaxtyping >= 0.3.2
- beartype >= 0.22.2
- pydantic >= 2.11.10
```python
import jax.numpy as jnp
from ejkernel.modules import flash_attention

# Basic usage - automatic configuration selection
output = flash_attention(
    query, key, value,
    causal=True,
    dropout_prob=0.1,
)

# With advanced features
output = flash_attention(
    query, key, value,
    causal=True,
    sliding_window=128,    # Local attention window
    logits_soft_cap=30.0,  # Gemma-2 style soft capping
    attention_mask=mask,   # Custom attention pattern
)
```

```python
from ejkernel.modules import FlashAttentionConfig
from ejkernel.ops.utils.datacarrier import FwdParams, BwdParams

# Create optimized configuration
config = FlashAttentionConfig(
    fwd_params=FwdParams(
        q_blocksize=256,
        kv_blocksize=256,
        num_warps=8,
        num_stages=2,
    ),
    bwd_params=BwdParams(
        q_blocksize=128,
        kv_blocksize=128,
        num_warps=4,
    ),
    platform="triton",  # Force specific backend
    backend="gpu",
)

output = flash_attention(query, key, value, cfg=config)
```

```python
from ejkernel import kernel_registry, Platform, Backend

# Get specific implementation
kernel = kernel_registry.get(
    algorithm="flash_attention",
    platform=Platform.TRITON,
    backend=Backend.GPU,
)

# Direct execution
output = kernel(query, key, value, causal=True)
```

```python
import jax
from jax.sharding import Mesh, PartitionSpec as P

from ejkernel.modules import flash_attention

# Setup mesh for distributed execution
devices = jax.devices()
mesh = Mesh(devices, axis_names=("data", "model"))

# Run distributed attention
output = flash_attention(
    query, key, value,
    causal=True,
    mesh=mesh,
    in_specs=(P("data", None), P("data", None), P("data", None)),
    out_specs=P("data", None),
)
```

ejKernel employs a sophisticated layered architecture that separates concerns while maintaining high performance:
```
┌─────────────────────────────────────────────────────┐
│                Public API (modules/)                │
│       Simple functions with sensible defaults       │
├─────────────────────────────────────────────────────┤
│               Operations Layer (ops/)               │
│    Configuration management, autotuning, caching    │
├─────────────────────────────────────────────────────┤
│             Kernel Registry (kernels/)              │
│       Platform routing, signature validation        │
├─────────────────────────────────────────────────────┤
│        Backend Implementations (kernels/_*)         │
│       Triton, CuTe, Pallas, XLA, CUDA kernels       │
└─────────────────────────────────────────────────────┘
```

The registry provides automatic platform-specific kernel selection:
```python
@kernel_registry.register("my_operation", Platform.TRITON, Backend.GPU, priority=100)
def my_operation_gpu(x, y):
    # GPU-optimized implementation
    pass

@kernel_registry.register("my_operation", Platform.XLA, Backend.ANY, priority=50)
def my_operation_fallback(x, y):
    # Universal fallback
    pass

# Automatic selection based on available hardware
impl = kernel_registry.get("my_operation")
```

Multi-tier configuration system with intelligent fallback:
```python
class ConfigSelectorChain:
    """
    Selection hierarchy:
        1. Override         - Explicit user configuration
        2. Overlay          - Temporary context overrides
        3. Memory Cache     - In-memory lookup
        4. Persistent Cache - Disk-based storage
        5. Autotune         - Performance benchmarking
        6. Heuristics       - Intelligent defaults
        7. Error            - Clear failure message
    """
```

All performance-critical kernels implement memory-efficient gradients:
```python
@jax.custom_vjp
def kernel_with_custom_grad(inputs):
    return forward(inputs)

def kernel_fwd(inputs):
    output, residuals = forward_with_residuals(inputs)
    return output, residuals

def kernel_bwd(residuals, grad_output):
    return efficient_backward(residuals, grad_output)

kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
```

| Algorithm | Description | Memory | Key Features |
|---|---|---|---|
| Flash Attention v2 | Memory-efficient exact attention | O(N) | Causal masking, dropout, sliding windows, soft capping |
| Ring Attention | Distributed sequence parallelism | O(N/P) | Ultra-long sequences, communication overlap, XLA single-device fallback |
| Page Attention | KV-cache optimized inference | O(N) | Block-wise memory, continuous batching |
| Block Sparse Attention | Configurable sparse patterns | O(N√N) | Local+global, custom patterns |
| GLA | Gated Linear Attention | O(N) | Linear complexity, gated updates |
| Lightning Attention | Layer-dependent decay | O(N) | Exponential moving average |
| MLA | Multi-head Latent Attention | O(N) | Compressed KV representation |
| Ragged Page Attention v2 | Variable-length paged attention | O(N) | Ragged sequences with page caching |
| Ragged Page Attention v3 | Enhanced ragged page attention | O(N) | Attention sinks support, improved handling |
| Ragged Decode Attention | Variable-length decoding | O(N) | Efficient batched inference |
| Kernel Delta Attention | Delta-rule linear attention | O(N) | Linear complexity, delta updates, decay control |
| Unified Attention | vLLM-style paged attention | O(N) | Segmented 3D decode kernel |
| Prefill Page Attention | Page attention prefill phase | O(N) | Separate prefill handling |
| Decode Attention | Single-token decode attention | O(N) | Optimized single-step decoding |
| Chunked Prefill Paged Decode | Combined prefill + decode | O(N) | Chunked prefill with paged KV cache decode |
| Flash MLA | Multi-head Latent Attention | O(N) | Low-rank KV compression, memory-efficient inference |
| Scaled Dot-Product Attention | Standard attention | O(N²) | Basic reference implementation |
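All of the exact-attention kernels above compute the same values as the quadratic reference in the last row; a minimal NumPy sketch (single head, no batch dimension; `sdpa_reference` is a hypothetical helper name, not part of the library) is handy for validating kernel outputs:

```python
import numpy as np

def sdpa_reference(q, k, v, causal=True):
    """O(N^2) reference: softmax(q @ k.T / sqrt(d)) @ v."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)
    if causal:
        n = scores.shape[0]
        # Mask out future positions with -inf before the softmax
        scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v
```

On matching inputs, a memory-efficient kernel such as `flash_attention` should agree with this reference up to floating-point tolerance; only the memory footprint differs.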
| Operation | Description | Key Features |
|---|---|---|
| RWKV-4 | Time-mix recurrence | Numerically stable (α,β,ε) state, O(N) memory |
| RWKV-6 | Multi-head linear attention | Variable-length packing, reverse mode, O(N) memory |
| RWKV-7 | DPLR (Diagonal + Low-Rank) recurrence | (a,b) parameterization, state-space inspired |
| RWKV-7 Mul | Multiplicative RWKV-7 variant | (kk,a) reparameterization for optimized kernels |
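For intuition, the RWKV-4 time-mix row above corresponds to the WKV recurrence. Below is a deliberately naive O(T²) NumPy reference (illustrative only; the actual kernels use the numerically stable (α,β,ε) running state and O(N) memory):

```python
import numpy as np

def wkv_reference(w, u, k, v):
    """Naive RWKV-4 WKV: out_t is a decayed, key-weighted average of v_0..v_t.

    out_t = (sum_{i<t} e^{k_i - (t-1-i) w} v_i + e^{u + k_t} v_t)
          / (sum_{i<t} e^{k_i - (t-1-i) w}     + e^{u + k_t})
    with per-channel decay w > 0 and current-token bonus u.
    """
    T, C = k.shape
    out = np.zeros_like(v)
    for t in range(T):
        num = np.exp(u + k[t]) * v[t]
        den = np.exp(u + k[t])
        for i in range(t):
            decay = np.exp(k[i] - (t - 1 - i) * w)
            num += decay * v[i]
            den += decay
        out[t] = num / den
    return out
```

Because every output is a convex combination of past values, a constant `v` passes through unchanged, which makes a quick sanity check for any WKV implementation.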
| Operation | Description | Use Case |
|---|---|---|
| Grouped MatMul | Efficient batched matrix operations | Expert models, MoE |
| Grouped MatMul v2 | Enhanced with shard_map support | Distributed expert models |
| Mean Pooling | Variable-length sequence aggregation | Sentence embeddings |
| Recurrent | Optimized RNN/LSTM/GRU operations | Sequential modeling |
| Native Sparse | Block-sparse matrix computations | Sparse attention patterns |
| Quantized MatMul | Multi-mode quantized matmul (affine, NF4, MXFP4/8, NVFP4/8) | Low-bit inference |
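The affine mode of Quantized MatMul amounts to approximating weights as `w ≈ scale * (q - zero_point)` and folding the dequantization into the GEMM. A hedged per-tensor NumPy sketch (illustrative only; it does not reflect the library's packing or its NF4/MXFP/NVFP formats):

```python
import numpy as np

def quantize_affine(w, bits=8):
    """Per-tensor affine quantization: w ~= scale * (q - zero_point)."""
    qmax = 2 ** bits - 1
    lo, hi = w.min(), w.max()
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zero_point = np.round(-lo / scale)
    q = np.clip(np.round(w / scale + zero_point), 0, qmax).astype(np.uint8)
    return q, scale, zero_point

def quantized_matmul(x, q, scale, zero_point):
    """Dequantize-then-matmul; real kernels fuse this into the GEMM."""
    w = scale * (q.astype(np.float32) - zero_point)
    return x @ w
```

The reconstruction error per weight is at most half a quantization step, so the product stays close to the full-precision result while the stored weights shrink to one byte each.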
| Operation | Description | Key Features |
|---|---|---|
| State Space v1 | Mamba1-style SSM | 2D A matrix, separate dt_proj, custom VJP for memory efficiency |
| State Space v2 | Mamba2-style SSM | Per-head scalar A, n_groups for parameter grouping, optional gated RMSNorm |
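Both state-space variants reduce, per channel, to a discretized linear recurrence. A toy scalar-state NumPy sketch of that recurrence (conceptual only; the real kernels vectorize over a 2D or per-head state and use custom VJPs):

```python
import numpy as np

def ssm_reference(x, dt, a, b, c):
    """Naive scan for one 1-D state-space channel, ZOH-style discretization:

        h_t = exp(dt_t * a) * h_{t-1} + dt_t * b * x_t
        y_t = c * h_t
    """
    h = 0.0
    ys = []
    for xt, dtt in zip(x, dt):
        h = np.exp(dtt * a) * h + dtt * b * xt  # decay old state, add input
        ys.append(c * h)
    return np.array(ys)
```

With `a = 0` the state never decays and the scan degenerates to a cumulative sum, which makes an easy correctness check; negative `a` (as in Mamba) gives the exponential forgetting that keeps long sequences stable.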
| Operation | Triton (GPU) | CUTE (GPU) | CUDA (GPU) | Pallas (TPU) | XLA (Universal) |
|---|---|---|---|---|---|
| Flash Attention v2 | ✅ | ✅ | ✅ | ✅ | ✅ |
| Flash MLA | ✅ | - | - | - | ✅ |
| Ring Attention | ✅ | - | - | ✅ | ✅ |
| Page Attention | ✅ | - | - | ✅ | ✅ |
| Block Sparse Attention | ✅ | - | ✅ | ✅ | ✅ |
| Decode Attention | ✅ | - | - | - | ✅ |
| Chunked Prefill Paged Decode | ✅ | ✅ | - | - | ✅ |
| Ragged Page Attention v2 | ✅ | - | - | ✅ | ✅ |
| Ragged Page Attention v3 | ✅ | - | ✅ | ✅ | ✅ |
| Ragged Decode Attention | ✅ | - | - | ✅ | ✅ |
| GLA | ✅ | - | - | - | ✅ |
| Lightning Attention | ✅ | - | - | - | ✅ |
| Recurrent | ✅ | - | - | - | ✅ |
| Mean Pooling | ✅ | - | - | - | ✅ |
| Grouped MatMul | - | - | - | ✅ | ✅ |
| Grouped MatMul v2 | - | - | - | ✅ | - |
| Native Sparse Attention | ✅ | - | - | - | ✅ |
| Quantized MatMul | ✅ | ✅ | ✅ | ✅ | ✅ |
| Kernel Delta Attention | - | - | - | - | ✅ |
| Unified Attention | ✅ | ✅ | ✅ | - | ✅ |
| Prefill Page Attention | - | - | - | ✅ | ✅ |
| Scaled Dot-Product Attention | - | - | - | - | ✅ |
| State Space v1 | - | - | - | - | ✅ |
| State Space v2 | - | - | - | - | ✅ |
| RWKV-4 | ✅ | - | - | - | ✅ |
| RWKV-6 | ✅ | - | - | - | ✅ |
| RWKV-7 | ✅ | - | - | - | ✅ |
| RWKV-7 Mul | ✅ | - | - | - | ✅ |
✅ = Production ready | - = Not available
- CuTe backend uses the TVM-FFI primitive path with fused kernels.
- Quantized MatMul on TPU uses hybrid dispatch (packed Pallas / predecode / XLA fallback).
- Distributed matmul ops (all_gather_matmul, reduce_scatter_matmul) intentionally do not perform runtime fallback between distributed backends; choose platform/cfg.platform explicitly.
```python
from ejkernel.modules import page_attention, PageAttentionConfig

# Configure paged attention for inference
config = PageAttentionConfig(
    platform="auto",
    backend="gpu",
)

output = page_attention(
    query=q,
    key_cache=k_cache,
    value_cache=v_cache,
    block_table=block_table,
    cache_seqlens=cache_seqlens,
    cfg=config,
)
```

```python
from ejkernel.modules import ragged_page_attention_v3, RaggedPageAttentionv3Config

# For variable-length sequences with attention sinks
config = RaggedPageAttentionv3Config(
    platform="pallas",
    backend="tpu",
)

output = ragged_page_attention_v3(
    query=q,
    key_pages=k_pages,
    value_pages=v_pages,
    lengths=seq_lengths,
    page_indices=page_indices,
    cfg=config,
)
```

```python
import os

# Force autotuning for optimal configuration
os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
os.environ["EJKERNEL_LOG_AUTOTUNE"] = "1"

# Enable profiling
os.environ["EJKERNEL_OPS_STAMP"] = "json"  # Detailed metadata
os.environ["EJKERNEL_OPS_RECORD"] = "1"    # Record invocations
```

```python
from dataclasses import dataclass

from jaxtyping import Array

from ejkernel import kernel_registry
from ejkernel.modules.operations.configs import BaseOperationConfig
from ejkernel.ops.core import Kernel

@dataclass
class MyConfig(BaseOperationConfig):
    param1: int = 128
    param2: float = 0.1

class MyKernel(Kernel[MyConfig, Array]):
    def __init__(self):
        super().__init__(op_id="my_kernel")

    def run(self, x, cfg: MyConfig):
        impl = kernel_registry.get("my_kernel", cfg.platform)
        return impl(x, param1=cfg.param1, param2=cfg.param2)

    def heuristic_cfg(self, inv):
        # Return default configuration
        return MyConfig(param1=256)

    def candidate_cfgs(self, inv):
        # Return autotuning candidates
        return [MyConfig(param1=p) for p in [64, 128, 256]]
```

```python
import flax.linen as nn

from ejkernel.modules import flash_attention

class TransformerBlock(nn.Module):
    num_heads: int = 8
    head_dim: int = 64

    @nn.compact
    def __call__(self, x, mask=None):
        # Project to Q, K, V
        q = nn.Dense(self.num_heads * self.head_dim)(x)
        k = nn.Dense(self.num_heads * self.head_dim)(x)
        v = nn.Dense(self.num_heads * self.head_dim)(x)

        # Reshape for attention
        shape = (x.shape[0], x.shape[1], self.num_heads, self.head_dim)
        q, k, v = map(lambda t: t.reshape(shape), (q, k, v))

        # Apply ejKernel Flash Attention
        attn_output = flash_attention(
            q, k, v,
            causal=True,
            attention_mask=mask,
        )

        # Project output
        return nn.Dense(x.shape[-1])(attn_output.reshape(x.shape))
```

```bash
# Clone repository
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel

# Create virtual environment
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate

# Install in development mode
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install
```

The project uses:
- black for code formatting (line length: 121)
- ruff for linting
- mypy/pyright for type checking
- pre-commit for automated checks
- Implement the kernel in the appropriate backend directory:

```python
# ejkernel/kernels/_triton/my_kernel/_interface.py
@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
def my_kernel_triton(x, config):
    # Implementation
    pass
```

- Create a module wrapper:

```python
# ejkernel/modules/operations/my_kernel.py
class MyKernel(Kernel[MyKernelConfig, Array]):
    # Module implementation
    pass
```

- Add tests:

```python
# test/kernels/_triton/test_my_kernel.py
class TestMyKernel(unittest.TestCase):
    # Test implementation
    pass
```

- Update documentation
```bash
# Run all tests
pytest test/

# Platform-specific tests
pytest test/kernels/_xla/     # XLA implementations
pytest test/kernels/_triton/  # Triton implementations
pytest test/kernels/_pallas/  # Pallas implementations

# Specific test patterns
pytest -k "flash_attention"
pytest --verbose --exitfirst

# Module operations tests
pytest test/modules/operations
```

- Unit Tests: Individual component testing
- Integration Tests: End-to-end workflows
- Comparison Tests: Cross-backend consistency
- Performance Tests: Regression detection
Run benchmarks to compare performance across backends:
```bash
# General attention benchmarks
python benchmarks/benchmark_attention.py

# Flash attention benchmarks
python benchmarks/benchmark_flash_attention.py

# Ragged page attention benchmarks
python benchmarks/benchmark_ragged_page_attention_v3.py
```

We welcome contributions!
- TPU/Pallas implementations for existing algorithms
- CUDA native kernels for maximum performance
- New attention mechanisms from recent papers
- Performance optimizations and kernel fusion
- Documentation and examples
- Fork the repository
- Create a feature branch
- Implement your changes with tests
- Ensure all tests pass
- Submit a pull request
Comprehensive documentation available at ejkernel.readthedocs.io
- API Reference: Complete API documentation
- Tutorials: Step-by-step guides
- Architecture: Design documentation
- Benchmarks: Performance analysis
If you use ejKernel in your research, please cite:
```bibtex
@software{ejkernel2025,
  author = {Erfan Zare Chavoshi},
  title = {ejKernel: High-Performance JAX Kernels for Deep Learning},
  year = {2025},
  url = {https://github.com/erfanzar/ejkernel},
  note = {Production-grade kernel library with multi-backend support}
}
```

ejKernel is licensed under the Apache License 2.0. See LICENSE for details.
ejKernel builds upon excellent work from:
- JAX - Composable transformations of Python+NumPy programs
- Triton - GPU kernel programming language
- Pallas - JAX kernel language
- Flash Attention - Memory-efficient attention
- EasyDeL - Parent framework for JAX deep learning
- GitHub Issues: Bug reports and feature requests
- Discussions: Community forum
- Email: Erfanzare810@gmail.com
ejKernel - Production-grade kernels for JAX deep learning