A DeepSpeed-based framework for distributed training and inference of language models.

## Features

- DeepSpeed Integration: Full support for ZeRO stages 0-3 for memory-efficient distributed training
- PEFT/LoRA Support: Built-in support for parameter-efficient fine-tuning methods
- Multi-GPU Training: Seamless distributed training across multiple GPUs
- Flexible Pipeline Architecture: Customizable data processing, forward functions, and model building
- Logging Integration: TensorBoard and Weights & Biases support out of the box
- Checkpoint Management: Automatic model saving with configurable intervals and metric monitoring

## Requirements

- Python 3.10 or higher
- CUDA-capable GPU (for training)
- PyTorch 2.0+
## Installation

Install from the repository:

```bash
git clone https://github.com/SakanaAI/ike.git
cd ike
pip install -e .
```

For development (includes testing and linting tools):

```bash
pip install -e ".[dev]"
```

For flash attention support:

```bash
pip install -e ".[flash]"
```

## Quick Start

Define a data processor:

```python
from ike import DataProcessor

class MyDataProcessor(DataProcessor):
    def line2data(self, line: dict) -> dict:
        # Process each data sample
        text = line["text"]
        tokens = self.tokenizer.encode(text)
        return {"input_ids": tokens}
```

Define a training forward step:

```python
def train_forward_step(step, accum_idx, model, tokenizer, batch_data, config):
    outputs = model(**batch_data)
    loss = outputs.loss
    return loss, {}, {"loss": loss.item()}
```

Build and run the pipeline:

```python
from ike import TrainingPipeline, get_arguments, load_data_from_jsonl

config = get_arguments()
pipeline = TrainingPipeline(config, world_size, local_rank, global_rank)
pipeline.run(
    load_data_from_filepath_fn=load_data_from_jsonl,
    data_processor_classes=[MyDataProcessor],
    train_forward_step_fn=train_forward_step,
    valid_forward_step_fn=valid_forward_step,
)
```

Launch with DeepSpeed:

```bash
deepspeed --include localhost:0,1 --master_port 12300 training.py \
    -c cfgs/training.cfg cfgs/model.cfg \
    --save_log --save_model
```

## Architecture

The framework follows a Task → Pipeline → Modules pattern:
```
Task
  ↓ instantiates (by providing configurations and customized modules)
Pipeline
  ↓ configures pre-implemented modules (optimizer creation, LR scheduler,
    DeepSpeed, logging) with the input configurations
  ↓ fills in customizable modules (data loader/processor/source, model
    creation, forward function, metrics) with the input customizable modules
  ↓ executes the modules
```
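The pattern above can be sketched in plain Python. Everything here is illustrative wiring, not ike's actual API: the point is only that the pipeline owns the pre-implemented pieces and the task plugs in the customizable ones.

```python
from dataclasses import dataclass, field
from typing import Callable

# Illustrative sketch of the Task -> Pipeline -> Modules pattern.
# None of these names are ike's real API; they only show the wiring.

@dataclass
class Pipeline:
    config: dict
    # Customizable module supplied by the task.
    forward_step_fn: Callable[[dict, dict], float] = None
    # Backing store for the pre-implemented logging module.
    log: list = field(default_factory=list)

    def _build_logger(self):
        # Pre-implemented module: configured by the pipeline, not the task.
        return self.log.append

    def run(self, batches):
        logger = self._build_logger()
        results = []
        for batch in batches:
            # Fill in the customizable module at each step.
            loss = self.forward_step_fn(batch, self.config)
            logger({"loss": loss})
            results.append(loss)
        return results

# The "task" provides configuration plus its customized modules.
def my_forward_step(batch, config):
    return sum(batch["values"]) * config["scale"]

pipeline = Pipeline(config={"scale": 0.5}, forward_step_fn=my_forward_step)
losses = pipeline.run([{"values": [1, 2]}, {"values": [3]}])
print(losses)  # [1.5, 1.5]
```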
## Core Components

- **Pipelines** (`TrainingPipeline`, `InferencePipeline`): High-level abstractions that orchestrate distributed training/inference, data loading, model management, and logging.
- **Data Processing** (`DataProcessor`, `BasicDataSource`): Flexible data loading and processing with support for JSONL files and HuggingFace datasets.
- **Configuration** (`get_arguments`, `get_inference_arguments`): YAML-based configuration with command-line override support.
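For illustration, a loader in the spirit of `load_data_from_jsonl` can be written in a few lines of standard-library Python. This is a hedged sketch of the JSONL convention (one JSON object per line), not ike's actual implementation.

```python
import json
from pathlib import Path

def load_jsonl(filepath) -> list[dict]:
    """Load one JSON object per line; a sketch of what a
    load_data_from_filepath_fn might do (not ike's real code)."""
    records = []
    with open(filepath, encoding="utf-8") as f:
        for raw in f:
            raw = raw.strip()
            if raw:  # skip blank lines
                records.append(json.loads(raw))
    return records

# Example: write a tiny JSONL file and load it back.
path = Path("sample.jsonl")
path.write_text('{"text": "hello"}\n{"text": "world"}\n')
print(load_jsonl(path))  # [{'text': 'hello'}, {'text': 'world'}]
```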
## Customizable Modules

The `TrainingPipeline.run()` method accepts these customizable modules:

| Module | Required | Description |
|---|---|---|
| `load_data_from_filepath_fn` | Yes | Function to load data from file paths |
| `data_processor_classes` | Yes | List of `DataProcessor` subclasses |
| `train_forward_step_fn` | Yes | Training forward pass implementation |
| `valid_forward_step_fn` | Yes | Validation forward pass implementation |
| `build_tokenizer_fn` | No | Custom tokenizer builder |
| `build_model_fn` | No | Custom model builder |
| `build_optimizer_fn` | No | Custom optimizer builder |
| `build_lr_scheduler_fn` | No | Custom LR scheduler builder |
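As one example of the optional builders, a custom `build_lr_scheduler_fn` commonly implements linear warmup followed by cosine decay. The schedule itself reduces to a pure function; the builder signature ike expects is not documented here, so the function below only shows the math such a builder would wrap (name and signature are ours).

```python
import math

def warmup_cosine_lr(step: int, peak_lr: float,
                     warmup_steps: int, total_steps: int) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero.
    A common schedule a custom build_lr_scheduler_fn might wrap;
    the name and signature here are illustrative, not ike API."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Progress through the decay phase, clipped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

print(warmup_cosine_lr(0, 3e-4, 100, 1000))    # 0.0
print(warmup_cosine_lr(100, 3e-4, 100, 1000))  # 0.0003 (peak)
```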
## Examples

See the `examples/` directory for complete working examples:

- `examples/lm/`: Language model fine-tuning example
- `examples/gsm8k/`: Supervised fine-tuning on the GSM8K math dataset

Each example includes:

- Task-specific data processors
- Training and evaluation scripts
- Configuration files in `cfgs/`
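For supervised fine-tuning tasks like the GSM8K example, a data processor typically concatenates prompt and answer tokens and masks the prompt positions in the labels with -100, so only answer tokens contribute to the loss. The following is a self-contained sketch of that technique with a toy whitespace tokenizer; it is not the actual example code.

```python
# Sketch of SFT label masking. The toy "tokenizer" and the helper
# below are illustrative, not ike's or the example's actual code.

def encode(text: str) -> list[int]:
    # Toy tokenizer: one pseudo-id per whitespace-separated word.
    return [hash(w) % 1000 for w in text.split()]

def build_sft_sample(prompt: str, answer: str) -> dict:
    prompt_ids = encode(prompt)
    answer_ids = encode(answer)
    input_ids = prompt_ids + answer_ids
    # -100 is ignored by cross-entropy in PyTorch/HF, so the model
    # is only trained to predict the answer tokens.
    labels = [-100] * len(prompt_ids) + answer_ids
    return {"input_ids": input_ids, "labels": labels}

sample = build_sft_sample("Q: 2 + 3 = ?", "A: 5")
assert len(sample["input_ids"]) == len(sample["labels"])
assert sample["labels"][:6] == [-100] * 6  # all 6 prompt tokens masked
```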
## Configuration

Configuration is managed through YAML files and command-line arguments. Key argument groups:

**Model**

- `--pretrained_model_dir`: Path to the pretrained model
- `--tokenizer_dir`: Path to the tokenizer (defaults to the model directory)
- `--attn_implementation`: Attention implementation (`eager`, `flash_attention_2`)

**Training**

- `--global_batch_size`: Total batch size across all GPUs
- `--micro_batch_size`: Batch size per GPU per step
- `--n_epochs`: Number of training epochs
- `--peak_lr`: Peak learning rate

**DeepSpeed**

- `--zero_stage`: ZeRO optimization stage (0, 1, 2, or 3)
- `--bf16`: Enable bfloat16 training
- `--activation_checkpointing_layers`: Number of layers for activation checkpointing

**PEFT**

- `--peft_type`: PEFT method (`LORA`)
- `--peft_lora_r`: LoRA rank
- `--peft_lora_alpha`: LoRA alpha parameter
- `--peft_lora_target_modules`: Modules to apply LoRA to
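The batch-size arguments above are tied together by the standard data-parallel identity: global batch size = micro batch size × gradient accumulation steps × number of data-parallel GPUs. A quick check of that arithmetic (the relationship is standard DeepSpeed behavior; the helper function itself is ours, not ike API):

```python
def grad_accum_steps(global_batch_size: int, micro_batch_size: int,
                     n_gpus: int) -> int:
    """Accumulation steps needed so micro batches sum to the global batch:
    global_batch_size = micro_batch_size * grad_accum_steps * n_gpus.
    Illustrative helper, not part of ike."""
    per_step = micro_batch_size * n_gpus
    if global_batch_size % per_step != 0:
        raise ValueError(
            "global batch size must be divisible by micro_batch_size * n_gpus")
    return global_batch_size // per_step

# e.g. --global_batch_size 64 --micro_batch_size 4 on 2 GPUs:
print(grad_accum_steps(64, 4, 2))  # 8
```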
See README_args.md for the complete argument reference.
## Troubleshooting

**Out of Memory (OOM)**

- Reduce `--micro_batch_size`
- Increase `--zero_stage` (try 2 or 3)
- Enable `--activation_checkpointing_layers`
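To see why raising `--zero_stage` helps with OOM: under mixed-precision Adam, model states cost roughly 16 bytes per parameter (2 for fp16 weights, 2 for fp16 gradients, 12 for fp32 optimizer states), and each ZeRO stage partitions more of that across the data-parallel GPUs. The breakdown follows the ZeRO paper; the helper below is our illustration, and it ignores activation memory entirely.

```python
def model_state_bytes_per_gpu(n_params: float, zero_stage: int,
                              n_gpus: int) -> float:
    """Approximate per-GPU bytes for model states with fp16 + Adam,
    using the ZeRO paper's 2 (weights) + 2 (grads) + 12 (optimizer
    states) bytes/param breakdown. Activations are not included."""
    weights, grads, optim = 2 * n_params, 2 * n_params, 12 * n_params
    if zero_stage >= 1:
        optim /= n_gpus      # stage 1: partition optimizer states
    if zero_stage >= 2:
        grads /= n_gpus      # stage 2: also partition gradients
    if zero_stage >= 3:
        weights /= n_gpus    # stage 3: also partition weights
    return weights + grads + optim

# 7B parameters on 8 GPUs:
for stage in (0, 1, 2, 3):
    gb = model_state_bytes_per_gpu(7e9, stage, 8) / 1e9
    print(f"stage {stage}: ~{gb:.1f} GB per GPU")
```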
**Slow Training**

- Ensure `--bf16` is enabled
- Try `--attn_implementation flash_attention_2`
- Adjust `--gradient_accumulation_steps`
**Data Loading Issues**

- Use `--debug_mode` to disable multiprocessing for easier debugging
- Check `--data_processor_chunksize` for memory issues

## License

This project is licensed under the MIT License - see the LICENSE file for details.