
ike

License: MIT · Python 3.10+

A DeepSpeed-based framework for distributed training and inference of language models.

Features

  • DeepSpeed Integration: Full support for ZeRO stages 0-3 for memory-efficient distributed training
  • PEFT/LoRA Support: Built-in support for parameter-efficient fine-tuning methods
  • Multi-GPU Training: Seamless distributed training across multiple GPUs
  • Flexible Pipeline Architecture: Customizable data processing, forward functions, and model building
  • Logging Integration: TensorBoard and Weights & Biases support out of the box
  • Checkpoint Management: Automatic model saving with configurable intervals and metric monitoring


Prerequisites

  • Python 3.10 or higher
  • CUDA-capable GPU (for training)
  • PyTorch 2.0+

Installation

Install from the repository:

git clone https://github.com/SakanaAI/ike.git
cd ike
pip install -e .

For development (includes testing and linting tools):

pip install -e ".[dev]"

For flash attention support:

pip install -e ".[flash]"

Quick Start

1. Create a data processor

from ike import DataProcessor

class MyDataProcessor(DataProcessor):
    def line2data(self, line: dict) -> dict:
        # Process each data sample
        text = line["text"]
        tokens = self.tokenizer.encode(text)
        return {"input_ids": tokens}
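The `line2data` contract can be exercised without the framework. The sketch below mirrors `MyDataProcessor` above but uses a hypothetical stand-in tokenizer instead of the HuggingFace tokenizer that ike provides via `self.tokenizer`; both `DemoDataProcessor` and `StubTokenizer` are illustrative names, not part of ike's API.

```python
# Illustrative only: a stand-in tokenizer to demonstrate the line2data
# contract. A real pipeline uses the tokenizer that ike supplies via
# self.tokenizer; StubTokenizer here is a hypothetical placeholder.
class StubTokenizer:
    def encode(self, text: str) -> list[int]:
        # Map each whitespace-separated word to a fake integer id.
        return [hash(w) % 1000 for w in text.split()]

class DemoDataProcessor:
    """Mirrors MyDataProcessor above, minus the ike base class."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def line2data(self, line: dict) -> dict:
        text = line["text"]
        tokens = self.tokenizer.encode(text)
        return {"input_ids": tokens}

processor = DemoDataProcessor(StubTokenizer())
sample = processor.line2data({"text": "hello world"})
print(len(sample["input_ids"]))  # one id per word -> 2
```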

2. Define forward functions

def train_forward_step(step, accum_idx, model, tokenizer, batch_data, config):
    outputs = model(**batch_data)
    loss = outputs.loss
    return loss, {}, {"loss": loss.item()}
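Step 3 below also passes a `valid_forward_step_fn`. Assuming it follows the same `(loss, state, metrics)` return contract as the training step (an assumption, not confirmed by the docs; in a real run you would also wrap the forward pass in `torch.no_grad()`), a minimal sketch with a hypothetical stub model:

```python
from types import SimpleNamespace

def valid_forward_step(step, accum_idx, model, tokenizer, batch_data, config):
    # Assumed to share the (loss, state, metrics) return contract of
    # train_forward_step above.
    outputs = model(**batch_data)
    loss = outputs.loss
    return loss, {}, {"valid_loss": loss.item()}

# Hypothetical stub model for illustration; a real run receives the
# DeepSpeed-wrapped model that ike builds.
def stub_model(input_ids=None):
    loss_value = float(len(input_ids))
    return SimpleNamespace(loss=SimpleNamespace(item=lambda: loss_value))

loss, state, metrics = valid_forward_step(
    0, 0, stub_model, None, {"input_ids": [1, 2, 3]}, None
)
print(metrics)  # {'valid_loss': 3.0}
```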

3. Run training

from ike import TrainingPipeline, get_arguments, load_data_from_jsonl

config = get_arguments()
# world_size, local_rank and global_rank come from the distributed
# launch environment (e.g. the WORLD_SIZE/LOCAL_RANK/RANK variables
# set by the DeepSpeed launcher).
pipeline = TrainingPipeline(config, world_size, local_rank, global_rank)
pipeline.run(
    load_data_from_filepath_fn=load_data_from_jsonl,
    data_processor_classes=[MyDataProcessor],
    train_forward_step_fn=train_forward_step,
    valid_forward_step_fn=valid_forward_step,
)

4. Launch with DeepSpeed

deepspeed --include localhost:0,1 --master_port 12300 training.py \
    -c cfgs/training.cfg cfgs/model.cfg \
    --save_log --save_model

Architecture

The framework follows a Task → Pipeline → Modules pattern:

Task
  ↓  instantiates (by providing configurations and customized modules)
Pipeline
  ├─ configures the pre-implemented modules (optimizer creation, LR scheduler, DeepSpeed, logging) with the input configurations
  ├─ fills in the customizable modules (data loader/processor/source, model creation, forward functions, metrics) with the supplied custom modules
  └─ executes the modules
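The pattern above can be sketched generically. Every name in this snippet is illustrative, not ike's actual internals: the pipeline owns pre-implemented defaults, the task overrides or adds customizable modules, and `run()` executes them all.

```python
# Generic sketch of the Task -> Pipeline -> Modules pattern; names are
# illustrative and do not match ike's real API.
class Pipeline:
    def __init__(self, config, **custom_modules):
        self.config = config
        # Pre-implemented defaults that the pipeline configures itself.
        self.modules = {
            "build_optimizer": lambda cfg: f"optimizer(lr={cfg['lr']})",
        }
        # Fill in the customizable modules supplied by the task.
        self.modules.update(custom_modules)

    def run(self):
        # Execute every module and collect the results.
        return {name: fn(self.config) for name, fn in self.modules.items()}

pipeline = Pipeline(
    {"lr": 1e-4},
    forward_step=lambda cfg: "loss",  # task-specific customization
)
results = pipeline.run()
print(results["build_optimizer"])  # optimizer(lr=0.0001)
```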

Core Components

  1. Pipelines (TrainingPipeline, InferencePipeline): High-level abstractions that orchestrate distributed training/inference, data loading, model management, and logging.

  2. Data Processing (DataProcessor, BasicDataSource): Flexible data loading and processing with support for JSONL files and HuggingFace datasets.

  3. Configuration (get_arguments, get_inference_arguments): YAML-based configuration with command-line override support.

Pipeline Customization

The TrainingPipeline.run() method accepts these customizable modules:

| Module | Required | Description |
| --- | --- | --- |
| load_data_from_filepath_fn | Yes | Function to load data from file paths |
| data_processor_classes | Yes | List of DataProcessor subclasses |
| train_forward_step_fn | Yes | Training forward pass implementation |
| valid_forward_step_fn | Yes | Validation forward pass implementation |
| build_tokenizer_fn | No | Custom tokenizer builder |
| build_model_fn | No | Custom model builder |
| build_optimizer_fn | No | Custom optimizer builder |
| build_lr_scheduler_fn | No | Custom LR scheduler builder |

Examples

See the examples/ directory for complete working examples:

  • examples/lm/: Language model fine-tuning example
  • examples/gsm8k/: Supervised fine-tuning on GSM8K math dataset

Each example includes:

  • Task-specific data processors
  • Training and evaluation scripts
  • Configuration files in cfgs/

Configuration

Configuration is managed through YAML files and command-line arguments. Key argument groups:

Model Arguments

  • --pretrained_model_dir: Path to pretrained model
  • --tokenizer_dir: Path to tokenizer (defaults to model dir)
  • --attn_implementation: Attention implementation (eager, flash_attention_2)

Training Arguments

  • --global_batch_size: Total batch size across all GPUs
  • --micro_batch_size: Batch size per GPU per step
  • --n_epochs: Number of training epochs
  • --peak_lr: Peak learning rate
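Assuming ike follows the standard data-parallel convention (unverified here; check README_args.md), the batch-size knobs are tied together by the identity global_batch_size = micro_batch_size × gradient_accumulation_steps × num_gpus, so fixing any three determines the fourth:

```python
# Standard data-parallel batch accounting (assumed convention):
#   global_batch_size = micro_batch_size * gradient_accumulation_steps * num_gpus
micro_batch_size = 4
num_gpus = 2
global_batch_size = 64
gradient_accumulation_steps = global_batch_size // (micro_batch_size * num_gpus)
print(gradient_accumulation_steps)  # 8
```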

DeepSpeed Arguments

  • --zero_stage: ZeRO optimization stage (0, 1, 2, or 3)
  • --bf16: Enable bfloat16 training
  • --activation_checkpointing_layers: Number of layers for activation checkpointing

PEFT/LoRA Arguments

  • --peft_type: PEFT method (LORA)
  • --peft_lora_r: LoRA rank
  • --peft_lora_alpha: LoRA alpha parameter
  • --peft_lora_target_modules: Modules to apply LoRA to
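A hypothetical config sketch combining the argument groups above. The key names assume a direct mapping from the command-line flags (with `--` dropped); this mapping is an assumption, so verify against the files in cfgs/ and README_args.md:

```yaml
# Hypothetical training config sketch; key names are assumed, unverified.
pretrained_model_dir: /path/to/model
attn_implementation: flash_attention_2
global_batch_size: 64
micro_batch_size: 4
n_epochs: 3
peak_lr: 1.0e-4
zero_stage: 2
bf16: true
peft_type: LORA
peft_lora_r: 16
peft_lora_alpha: 32
```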

See README_args.md for the complete argument reference.

Troubleshooting

Common Issues

Out of Memory (OOM)

  • Reduce --micro_batch_size
  • Increase --zero_stage (try 2 or 3)
  • Set --activation_checkpointing_layers to enable activation checkpointing (trades recomputation for memory)

Slow Training

  • Ensure --bf16 is enabled
  • Try --attn_implementation flash_attention_2
  • Adjust --gradient_accumulation_steps

Data Loading Issues

  • Use --debug_mode to disable multiprocessing for easier debugging
  • Check --data_processor_chunksize for memory issues

License

This project is licensed under the MIT License - see the LICENSE file for details.
