
ejKernel: High-Performance JAX Kernels for Deep Learning

"The best optimization is the one you don't have to think about."


ejKernel is a production-grade kernel library for JAX that provides highly optimized implementations of deep learning operations with automatic multi-backend support. The library features a sophisticated configuration management system with autotuning, comprehensive type safety, and seamless execution across GPUs, TPUs, and CPUs.

Note

ejKernel contains no AI-generated code. All kernels, modules, and core logic are manually designed and implemented by human developers. AI tooling (Opus 4.5) is used exclusively for documentation, which may therefore contain minor inaccuracies. There is no "vibe coding" or automated code generation anywhere in the codebase.


Key Features

Intelligent Kernel Management

  • 7-Tier Configuration System: Override → Overlay → Memory Cache → Persistent Cache → Autotune → Heuristics → Error
  • Automatic Platform Detection: Seamlessly selects optimal implementation based on hardware
  • Priority-Based Registry: Multi-backend support with intelligent fallback mechanisms
  • Device Fingerprinting: Hardware-specific configuration caching for optimal performance
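The fallback order of the configuration tiers can be sketched in plain Python. This is a simplified model with illustrative names, not ejKernel's actual API: each tier either produces a configuration or defers to the next, and exhausting every tier is an error.

```python
# Simplified sketch of tiered configuration lookup (illustrative names only).
def select_config(key, override=None, overlay=None, memory_cache=None,
                  persistent_cache=None, autotune=None, heuristic=None):
    """Walk the tiers in priority order and return the first hit."""
    tiers = [
        ("override", lambda: override),
        ("overlay", lambda: (overlay or {}).get(key)),
        ("memory_cache", lambda: (memory_cache or {}).get(key)),
        ("persistent_cache", lambda: (persistent_cache or {}).get(key)),
        ("autotune", lambda: autotune(key) if autotune else None),
        ("heuristic", lambda: heuristic(key) if heuristic else None),
    ]
    for name, lookup in tiers:
        cfg = lookup()
        if cfg is not None:
            return name, cfg
    raise LookupError(f"no configuration found for {key!r}")

# An explicit override wins over everything else:
print(select_config("flash_attention", override={"q_blocksize": 256}))

# With no override, lower tiers are consulted in order:
print(select_config("flash_attention",
                    memory_cache={"flash_attention": {"q_blocksize": 128}}))
```

The point of the ordering is that explicit user intent always beats cached or inferred settings, while autotuning only runs when no cached answer exists.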

State-of-the-Art Operations

  • 30+ Deep Learning Operations: Flash Attention v2, Flash MLA, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, Quantized MatMul, State Space Models (Mamba), RWKV (v4/v6/v7), and more
  • Memory Efficiency: Custom VJP implementations with O(N) memory complexity for attention
  • Distributed Support: Full shard_map integration for model and data parallelism
  • Mixed Precision: Comprehensive dtype support with automatic gradient conversion

Production-Ready Infrastructure

  • Type Safety: Full jaxtyping annotations with runtime validation via beartype
  • Comprehensive Testing: Cross-backend validation, performance benchmarks, integration tests
  • Atomic Persistence: Thread-safe configuration storage with automatic optimization
  • Profiling Integration: Built-in support for JAX profiling and performance monitoring

Installation

Basic Installation

pip install ejkernel

Platform-Specific Installation

# GPU Support (CUDA)
pip install ejkernel[cuda]

# TPU Support
pip install ejkernel[tpu]

# Development Installation
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel
pip install -e ".[dev]"

Dependencies

  • Python 3.11-3.13
  • JAX >= 0.9.0
  • Triton == 3.6.0 (for GPU)
  • nvidia-cutlass-dsl >= 4.4.0 (optional, for CuTe DSL kernels)
  • jax-tvm-ffi == 0.1.2 (optional, for CuTe TVM-FFI primitive path)
  • jaxtyping >= 0.3.2
  • beartype >= 0.22.2
  • pydantic >= 2.11.10

Quick Start

Simple API with Automatic Optimization

import jax.numpy as jnp
from ejkernel.modules import flash_attention

# Basic usage - automatic configuration selection
output = flash_attention(
    query, key, value,
    causal=True,
    dropout_prob=0.1
)

# With advanced features
output = flash_attention(
    query, key, value,
    causal=True,
    sliding_window=128,        # Local attention window
    logits_soft_cap=30.0,      # Gemma-2 style soft capping
    attention_mask=mask,       # Custom attention pattern
)
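The `logits_soft_cap` option bounds attention logits with a tanh squash, the Gemma-2 formulation `cap * tanh(logits / cap)`. As a plain-Python reference for the math (not ejKernel's kernel code):

```python
import math

def soft_cap(logit, cap=30.0):
    """Gemma-2 style soft capping: squashes any logit into (-cap, cap)
    while staying approximately linear for |logit| << cap."""
    return cap * math.tanh(logit / cap)

print(soft_cap(5.0))    # near-linear region: close to 5.0
print(soft_cap(500.0))  # saturates just below the cap of 30.0
```

This keeps extreme logits from dominating the softmax without clipping gradients to zero the way a hard clamp would.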

Custom Configuration

from ejkernel.modules import FlashAttentionConfig
from ejkernel.ops.utils.datacarrier import FwdParams, BwdParams

# Create optimized configuration
config = FlashAttentionConfig(
    fwd_params=FwdParams(
        q_blocksize=256,
        kv_blocksize=256,
        num_warps=8,
        num_stages=2
    ),
    bwd_params=BwdParams(
        q_blocksize=128,
        kv_blocksize=128,
        num_warps=4
    ),
    platform="triton",  # Force specific backend
    backend="gpu"
)

output = flash_attention(query, key, value, cfg=config)

Direct Kernel Registry Access

from ejkernel import kernel_registry, Platform, Backend

# Get specific implementation
kernel = kernel_registry.get(
    algorithm="flash_attention",
    platform=Platform.TRITON,
    backend=Backend.GPU
)

# Direct execution
output = kernel(query, key, value, causal=True)

Distributed Execution

import jax
from jax.sharding import Mesh, PartitionSpec as P
from ejkernel.modules import flash_attention

# Setup mesh for distributed execution
devices = jax.devices()
mesh = Mesh(devices, axis_names=("data", "model"))

# Run distributed attention
output = flash_attention(
    query, key, value,
    causal=True,
    mesh=mesh,
    in_specs=(P("data", None), P("data", None), P("data", None)),
    out_specs=P("data", None)
)

Architecture Overview

System Design

ejKernel employs a sophisticated layered architecture that separates concerns while maintaining high performance:

┌─────────────────────────────────────────────────────┐
│ Public API (modules/) │
│ Simple functions with sensible defaults │
├─────────────────────────────────────────────────────┤
│ Operations Layer (ops/) │
│ Configuration management, autotuning, caching │
├─────────────────────────────────────────────────────┤
│ Kernel Registry (kernels/) │
│ Platform routing, signature validation │
├─────────────────────────────────────────────────────┤
│ Backend Implementations (kernels/_*) │
│ Triton, CuTe, Pallas, XLA, CUDA kernels │
└─────────────────────────────────────────────────────┘

Core Components

Kernel Registry

The registry provides automatic platform-specific kernel selection:

@kernel_registry.register("my_operation", Platform.TRITON, Backend.GPU, priority=100)
def my_operation_gpu(x, y):
    # GPU-optimized implementation
    pass

@kernel_registry.register("my_operation", Platform.XLA, Backend.ANY, priority=50)
def my_operation_fallback(x, y):
    # Universal fallback
    pass

# Automatic selection based on available hardware
impl = kernel_registry.get("my_operation")
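The priority mechanism can be modelled in a few lines of plain Python (an illustrative sketch, not ejKernel's internals): registration stores candidates per algorithm, and lookup returns the highest-priority entry whose backend is usable, with "any" entries acting as universal fallbacks.

```python
# Minimal sketch of a priority-based kernel registry (illustrative only).
_REGISTRY = {}

def register(algorithm, platform, backend, priority=0):
    def decorator(fn):
        _REGISTRY.setdefault(algorithm, []).append((priority, platform, backend, fn))
        return fn
    return decorator

def get(algorithm, available_backends=("cpu",)):
    """Return the highest-priority implementation whose backend is usable."""
    candidates = [
        entry for entry in _REGISTRY.get(algorithm, [])
        if entry[2] in available_backends or entry[2] == "any"
    ]
    if not candidates:
        raise KeyError(f"no implementation registered for {algorithm!r}")
    return max(candidates, key=lambda e: e[0])[3]

@register("my_operation", "triton", "gpu", priority=100)
def my_operation_gpu(x, y):
    return ("gpu", x + y)

@register("my_operation", "xla", "any", priority=50)
def my_operation_fallback(x, y):
    return ("fallback", x + y)

print(get("my_operation", available_backends=("gpu",))(1, 2))  # GPU impl wins
print(get("my_operation", available_backends=("cpu",))(1, 2))  # falls back
```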

Configuration Management

Multi-tier configuration system with intelligent fallback:

class ConfigSelectorChain:
    """
    Selection hierarchy:
    1. Override - Explicit user configuration
    2. Overlay - Temporary context overrides
    3. Memory Cache - In-memory lookup
    4. Persistent Cache - Disk-based storage
    5. Autotune - Performance benchmarking
    6. Heuristics - Intelligent defaults
    7. Error - Clear failure message
    """

Custom VJP System

All performance-critical kernels implement memory-efficient gradients:

@jax.custom_vjp
def kernel_with_custom_grad(inputs):
    return forward(inputs)

def kernel_fwd(inputs):
    output, residuals = forward_with_residuals(inputs)
    return output, residuals

def kernel_bwd(residuals, grad_output):
    # custom_vjp's bwd must return a tuple with one cotangent per primal input
    return (efficient_backward(residuals, grad_output),)

kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
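The forward/backward split above can be exercised without JAX. A hand-rolled sketch of the residual-passing pattern for f(x) = x·exp(x), where the backward pass reuses values saved by the forward pass instead of recomputing them (this is a pattern illustration, not ejKernel code):

```python
import math

def fwd(x):
    """Forward pass: compute f(x) = x * exp(x) and save the residuals
    the backward pass will need (x and exp(x))."""
    e = math.exp(x)
    return x * e, (x, e)

def bwd(residuals, grad_output):
    """Backward pass: df/dx = exp(x) * (1 + x), built from residuals."""
    x, e = residuals
    return grad_output * e * (1.0 + x)

out, res = fwd(1.5)
grad = bwd(res, 1.0)

# Sanity check against a central finite difference:
h = 1e-6
numeric = (fwd(1.5 + h)[0] - fwd(1.5 - h)[0]) / (2 * h)
print(grad, numeric)  # the two agree to several decimal places
```

Saving residuals this way is what lets attention backward passes stay O(N) in memory: the forward pass stashes only small per-row statistics rather than the full attention matrix.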

Supported Operations

Attention Mechanisms

| Algorithm | Description | Memory | Key Features |
|---|---|---|---|
| Flash Attention v2 | Memory-efficient exact attention | O(N) | Causal masking, dropout, sliding windows, soft capping |
| Ring Attention | Distributed sequence parallelism | O(N/P) | Ultra-long sequences, communication overlap, XLA single-device fallback |
| Page Attention | KV-cache optimized inference | O(N) | Block-wise memory, continuous batching |
| Block Sparse Attention | Configurable sparse patterns | O(N√N) | Local+global, custom patterns |
| GLA | Gated Linear Attention | O(N) | Linear complexity, gated updates |
| Lightning Attention | Layer-dependent decay | O(N) | Exponential moving average |
| MLA | Multi-head Latent Attention | O(N) | Compressed KV representation |
| Ragged Page Attention v2 | Variable-length paged attention | O(N) | Ragged sequences with page caching |
| Ragged Page Attention v3 | Enhanced ragged page attention | O(N) | Attention sinks support, improved handling |
| Ragged Decode Attention | Variable-length decoding | O(N) | Efficient batched inference |
| Kernel Delta Attention | Delta-rule linear attention | O(N) | Linear complexity, delta updates, decay control |
| Unified Attention | vLLM-style paged attention | O(N) | Segmented 3D decode kernel |
| Prefill Page Attention | Page attention prefill phase | O(N) | Separate prefill handling |
| Decode Attention | Single-token decode attention | O(N) | Optimized single-step decoding |
| Chunked Prefill Paged Decode | Combined prefill + decode | O(N) | Chunked prefill with paged KV cache decode |
| Flash MLA | Multi-head Latent Attention | O(N) | Low-rank KV compression, memory-efficient inference |
| Scaled Dot-Product Attention | Standard attention | O(N²) | Basic reference implementation |
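Causal masking combined with a sliding window (as in Flash Attention v2 above) lets query position i see only keys j with j <= i and i - j < window. A plain-Python reference for the visibility mask (the function name is illustrative):

```python
def sliding_window_causal_mask(seq_len, window):
    """Boolean mask: mask[i][j] is True iff key j is visible to query i
    under causal masking with a local window of `window` positions."""
    return [
        [(j <= i) and (i - j < window) for j in range(seq_len)]
        for i in range(seq_len)
    ]

# Each query attends to itself plus at most window-1 previous positions:
for row in sliding_window_causal_mask(seq_len=5, window=2):
    print(["x" if m else "." for m in row])
```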

Recurrent Linear Attention (RWKV)

| Operation | Description | Key Features |
|---|---|---|
| RWKV-4 | Time-mix recurrence | Numerically stable (α,β,ε) state, O(N) memory |
| RWKV-6 | Multi-head linear attention | Variable-length packing, reverse mode, O(N) memory |
| RWKV-7 | DPLR (Diagonal + Low-Rank) recurrence | (a,b) parameterization, state-space inspired |
| RWKV-7 Mul | Multiplicative RWKV-7 variant | (kk,a) reparameterization for optimized kernels |

Other Operations

| Operation | Description | Use Case |
|---|---|---|
| Grouped MatMul | Efficient batched matrix operations | Expert models, MoE |
| Grouped MatMul v2 | Enhanced with shard_map support | Distributed expert models |
| Mean Pooling | Variable-length sequence aggregation | Sentence embeddings |
| Recurrent | Optimized RNN/LSTM/GRU operations | Sequential modeling |
| Native Sparse | Block-sparse matrix computations | Sparse attention patterns |
| Quantized MatMul | Multi-mode quantized matmul (affine, NF4, MXFP4/8, NVFP4/8) | Low-bit inference |
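Mean pooling over variable-length sequences reduces each padded sequence to the mean of only its valid positions. A pure-Python reference for the semantics (ejKernel's kernel operates on JAX arrays; this sketch uses nested lists for clarity):

```python
def mean_pool(batch, lengths):
    """batch: list of padded sequences of feature vectors;
    lengths[b] gives the number of valid positions in batch[b].
    Padding positions are excluded from the mean."""
    pooled = []
    for seq, n in zip(batch, lengths):
        dim = len(seq[0])
        acc = [0.0] * dim
        for t in range(n):            # only the first `n` positions count
            for d in range(dim):
                acc[d] += seq[t][d]
        pooled.append([a / n for a in acc])
    return pooled

batch = [
    [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]],  # length 2, last row is padding
    [[2.0, 2.0], [0.0, 0.0], [0.0, 0.0]],  # length 1
]
print(mean_pool(batch, lengths=[2, 1]))  # [[2.0, 3.0], [2.0, 2.0]]
```

Dividing by the true length rather than the padded length is the entire point: a naive mean over padded rows would bias embeddings toward zero for short sequences.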

State Space Models

| Operation | Description | Key Features |
|---|---|---|
| State Space v1 | Mamba1-style SSM | 2D A matrix, separate dt_proj, custom VJP for memory efficiency |
| State Space v2 | Mamba2-style SSM | Per-head scalar A, n_groups for parameter grouping, optional gated RMSNorm |

Platform Support Matrix

Operation Triton (GPU) CUTE (GPU) CUDA (GPU) Pallas (TPU) XLA (Universal)
Flash Attention v2
Flash MLA - - -
Ring Attention - -
Page Attention - -
Block Sparse Attention -
Decode Attention - - -
Chunked Prefill Paged Decode - -
Ragged Page Attention v2 - -
Ragged Page Attention v3 -
Ragged Decode Attention - -
GLA - - -
Lightning Attention - - -
Recurrent - - -
Mean Pooling - - -
Grouped MatMul - - -
Grouped MatMul v2 - - - -
Native Sparse Attention - - -
Quantized MatMul
Kernel Delta Attention - - - -
Unified Attention -
Prefill Page Attention - - -
Scaled Dot-Product Attention - - - -
State Space v1 - - - -
State Space v2 - - - -
RWKV-4 - - -
RWKV-6 - - -
RWKV-7 - - -
RWKV-7 Mul - - -

✅ = Production ready | - = Not available

  • CuTe backend uses the TVM-FFI primitive path with fused kernels.
  • Quantized MatMul on TPU uses hybrid dispatch (packed Pallas / predecode / XLA fallback).
  • Distributed matmul ops (all_gather_matmul, reduce_scatter_matmul) intentionally do not perform runtime fallback between distributed backends; choose platform/cfg.platform explicitly.

Advanced Usage

Page Attention for KV-Cache Inference

from ejkernel.modules import page_attention, PageAttentionConfig

# Configure paged attention for inference
config = PageAttentionConfig(
    platform="auto",
    backend="gpu"
)

output = page_attention(
    query=q,
    key_cache=k_cache,
    value_cache=v_cache,
    block_table=block_table,
    cache_seqlens=cache_seqlens,
    cfg=config
)
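The `block_table` indirection maps each logical token position in a sequence to a physical cache block, which is what lets the KV cache grow in fixed-size pages instead of one contiguous buffer. The address arithmetic can be sketched as follows (`page_size` and the function name are illustrative):

```python
def locate_token(block_table, seq_id, token_pos, page_size):
    """Translate a logical token position into a (physical_block, offset)
    pair via the per-sequence block table, as in paged KV caches."""
    logical_block = token_pos // page_size
    offset = token_pos % page_size
    physical_block = block_table[seq_id][logical_block]
    return physical_block, offset

# Sequence 0 owns physical blocks 7 and 3, each holding 16 tokens:
block_table = [[7, 3]]
print(locate_token(block_table, seq_id=0, token_pos=5, page_size=16))   # (7, 5)
print(locate_token(block_table, seq_id=0, token_pos=20, page_size=16))  # (3, 4)
```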

Ragged Page Attention for Variable-Length Batches

from ejkernel.modules import ragged_page_attention_v3, RaggedPageAttentionv3Config

# For variable-length sequences with attention sinks
config = RaggedPageAttentionv3Config(
    platform="pallas",
    backend="tpu"
)

output = ragged_page_attention_v3(
    query=q,
    key_pages=k_pages,
    value_pages=v_pages,
    lengths=seq_lengths,
    page_indices=page_indices,
    cfg=config
)

Performance Optimization

# Force autotuning for optimal configuration
import os
os.environ["EJKERNEL_AUTOTUNE_POLICY"] = "autotune"
os.environ["EJKERNEL_LOG_AUTOTUNE"] = "1"

# Enable profiling
os.environ["EJKERNEL_OPS_STAMP"] = "json"  # Detailed metadata
os.environ["EJKERNEL_OPS_RECORD"] = "1"    # Record invocations

Custom Kernel Development

from ejkernel.ops.core import Kernel
from ejkernel.modules.operations.configs import BaseOperationConfig
from dataclasses import dataclass

@dataclass
class MyConfig(BaseOperationConfig):
    param1: int = 128
    param2: float = 0.1

class MyKernel(Kernel[MyConfig, Array]):
    def __init__(self):
        super().__init__(op_id="my_kernel")

    def run(self, x, cfg: MyConfig):
        impl = kernel_registry.get("my_kernel", cfg.platform)
        return impl(x, param1=cfg.param1, param2=cfg.param2)

    def heuristic_cfg(self, inv):
        # Return default configuration
        return MyConfig(param1=256)

    def candidate_cfgs(self, inv):
        # Return autotuning candidates
        return [MyConfig(param1=p) for p in [64, 128, 256]]

Integration with Flax Models

import flax.linen as nn
from ejkernel.modules import flash_attention

class TransformerBlock(nn.Module):
    num_heads: int = 8
    head_dim: int = 64

    @nn.compact
    def __call__(self, x, mask=None):
        # Project to Q, K, V
        q = nn.Dense(self.num_heads * self.head_dim)(x)
        k = nn.Dense(self.num_heads * self.head_dim)(x)
        v = nn.Dense(self.num_heads * self.head_dim)(x)

        # Reshape for attention
        shape = (x.shape[0], x.shape[1], self.num_heads, self.head_dim)
        q, k, v = map(lambda t: t.reshape(shape), (q, k, v))

        # Apply ejKernel Flash Attention
        attn_output = flash_attention(
            q, k, v,
            causal=True,
            attention_mask=mask
        )

        # Project output
        return nn.Dense(x.shape[-1])(attn_output.reshape(x.shape))

Development

Setting Up Development Environment

# Clone repository
git clone https://github.com/erfanzar/ejkernel.git
cd ejkernel

# Create virtual environment
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate

# Install in development mode
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install

Code Style

The project uses:

  • black for code formatting (line length: 121)
  • ruff for linting
  • mypy/pyright for type checking
  • pre-commit for automated checks

Adding New Kernels

  1. Implement the kernel in the appropriate backend directory:
# ejkernel/kernels/_triton/my_kernel/_interface.py
@kernel_registry.register("my_kernel", Platform.TRITON, Backend.GPU)
def my_kernel_triton(x, config):
    # Implementation
    pass
  2. Create a module wrapper:
# ejkernel/modules/operations/my_kernel.py
class MyKernel(Kernel[MyKernelConfig, Array]):
    # Module implementation
    pass
  3. Add tests:
# test/kernels/_triton/test_my_kernel.py
class TestMyKernel(unittest.TestCase):
    # Test implementation
    pass
  4. Update the documentation

Testing

Running Tests

# Run all tests
pytest test/

# Platform-specific tests
pytest test/kernels/_xla/          # XLA implementations
pytest test/kernels/_triton/       # Triton implementations
pytest test/kernels/_pallas/       # Pallas implementations

# Specific test patterns
pytest -k "flash_attention"
pytest --verbose -x    # -x (--exitfirst) stops at the first failure

# Module operations tests
pytest test/modules/operations

Test Categories

  • Unit Tests: Individual component testing
  • Integration Tests: End-to-end workflows
  • Comparison Tests: Cross-backend consistency
  • Performance Tests: Regression detection
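A comparison test checks that two implementations of the same operation agree within a numerical tolerance. The sketch below is self-contained and uses a naive vs. numerically stable softmax as stand-ins for two backends (ejKernel's actual comparison tests run real kernels against each other):

```python
import math
import unittest

def softmax_naive(xs):
    exps = [math.exp(x) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def softmax_stable(xs):
    # Max-subtraction trick: same result, no overflow for large inputs.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

class TestCrossImplementation(unittest.TestCase):
    def test_implementations_agree(self):
        xs = [0.5, -1.0, 2.0, 0.0]
        for u, v in zip(softmax_naive(xs), softmax_stable(xs)):
            self.assertAlmostEqual(u, v, places=6)

if __name__ == "__main__":
    unittest.main(exit=False, argv=["test"])
```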

Benchmarking

Run benchmarks to compare performance across backends:

# General attention benchmarks
python benchmarks/benchmark_attention.py

# Flash attention benchmarks
python benchmarks/benchmark_flash_attention.py

# Ragged page attention benchmarks
python benchmarks/benchmark_ragged_page_attention_v3.py

Contributing

We welcome contributions!

Priority Areas

  • TPU/Pallas implementations for existing algorithms
  • CUDA native kernels for maximum performance
  • New attention mechanisms from recent papers
  • Performance optimizations and kernel fusion
  • Documentation and examples

Contribution Process

  1. Fork the repository
  2. Create a feature branch
  3. Implement your changes with tests
  4. Ensure all tests pass
  5. Submit a pull request

Documentation

Comprehensive documentation available at ejkernel.readthedocs.io

Citation

If you use ejKernel in your research, please cite:

@software{ejkernel2025,
  author = {Erfan Zare Chavoshi},
  title = {ejKernel: High-Performance JAX Kernels for Deep Learning},
  year = {2025},
  url = {https://github.com/erfanzar/ejkernel},
  note = {Production-grade kernel library with multi-backend support}
}

License

ejKernel is licensed under the Apache License 2.0. See LICENSE for details.

Acknowledgments

ejKernel builds upon excellent work from:

  • JAX - Composable transformations of Python+NumPy programs
  • Triton - GPU kernel programming language
  • Pallas - JAX kernel language
  • Flash Attention - Memory-efficient attention
  • EasyDeL - Parent framework for JAX deep learning


ejKernel - Production-grade kernels for JAX deep learning
