
Simply: Minimal Code for Frontier LLM Research in JAX

Simply is a minimal and scalable research codebase in JAX, designed for rapid iteration on frontier research in LLM and other autoregressive models.

  • Quick to fork and hack for fast iteration. You should be able to implement your research ideas (e.g., a new architecture, optimizer, or training loss) in a few hours.
  • Minimal abstractions and dependencies for a simple and self-contained codebase. Learn JAX (if you haven't already), and you are ready to read and hack the code.
  • That's it, simply get started with hacking now :)

Getting started

Example commands

Local test for debugging

EXP=simply_local_test_1; rm -rf /tmp/${EXP}; python -m simply.main --experiment_config lm_test --experiment_dir /tmp/${EXP} --alsologtostderr

Or, if you want to debug by printing arrays like normal Python code, you can disable JIT and use_scan with the command below.

export JAX_DISABLE_JIT=True; EXP=simply_local_test_1; rm -rf /tmp/${EXP}; python -m simply.main --experiment_config lm_no_scan_test --experiment_dir /tmp/${EXP} --alsologtostderr
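If you prefer to toggle this from inside the code, standard JAX offers equivalents to the environment variable. This is a minimal sketch using plain JAX APIs, independent of Simply's flags:

# In-code alternatives to JAX_DISABLE_JIT=True (standard JAX APIs).
import jax
import jax.numpy as jnp

jax.config.update('jax_disable_jit', True)  # disable JIT globally, or:

with jax.disable_jit():  # disable JIT for a single block
    y = jnp.tanh(jnp.arange(4.0))
    print(y)  # arrays evaluate eagerly here, so printing works as in NumPy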

RL training with Gemma 2B on GSM8K

Before running this example, you need to download the model checkpoints and datasets (see Setup Model Checkpoints and Datasets below).

EXP=gemma2b_gsm8k_rl; rm -rf /tmp/${EXP}; python -m simply.main --experiment_config gemma2_2b_gsm8k_0shot_rl --experiment_dir /tmp/${EXP} --alsologtostderr

Small-scale pretraining (2e16 FLOPs)

EXP=pretrain_small; rm -rf /tmp/${EXP}; python -m simply.main --experiment_config flops2e16_tfm15m_c4_l2048 --experiment_dir /tmp/${EXP} --alsologtostderr

Mesh shape configuration

The above examples use default mesh shapes: data parallelism (FSDP) for training, and model (tensor) parallelism for decoding. You can customize them with the flags below; a sketch of how these axes map onto devices follows the list.

  • --mesh_shape: Training mesh shape (e.g., --mesh_shape 1,8,1 for replica, data, model axes)
  • --decoding_mesh_shape: Decoding mesh shape
  • --dcn_mesh_shape: DCN mesh shape for multi-slice deployments
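To make the axis semantics concrete, here is a minimal sketch (standard JAX sharding APIs, not Simply's internals) of how a --mesh_shape of 1,8,1 maps onto devices; it assumes a host with 8 accelerators:

# Minimal sketch: a (replica, data, model) mesh over 8 devices.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((1, 8, 1))  # requires 8 devices
mesh = Mesh(devices, axis_names=('replica', 'data', 'model'))

# Shard the batch dimension across 'data' (FSDP-style data parallelism);
# for decoding, weights would instead be sharded along 'model'.
x = jax.device_put(jnp.ones((16, 128)), NamedSharding(mesh, P('data', None)))
print(x.sharding)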

Note: This is for dense models. Guide for MoE models coming soon.

Dependencies

The main dependencies are JAX for modeling and training, Orbax for checkpoint management, and SeqIO for the data pipeline.

Install dependencies:

# JAX installation is environment-specific. See https://docs.jax.dev/en/latest/installation.html
# CPU:
pip install -U jax
# GPU:
pip install -U "jax[cuda13]"
# TPU:
pip install -U "jax[tpu]"
# Other dependencies:
pip install -r requirements.txt
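As a quick sanity check that JAX sees your accelerators (standard JAX, works on any backend):

python -c "import jax; print(jax.devices())"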

Set Up Model Checkpoints and Datasets

Download datasets and model checkpoints, in a format supported by Simply, from HuggingFace:

# Install huggingface_hub
pip install huggingface_hub

# Download both models and datasets
python setup/setup_assets.py

# Or download only models/datasets
python setup/setup_assets.py --models-only
python setup/setup_assets.py --datasets-only

This will download models to ~/.cache/simply/models/ and datasets to ~/.cache/simply/datasets/. You can customize the locations with the --models-dir and --datasets-dir flags, or by setting the environment variables SIMPLY_MODELS and SIMPLY_DATASETS. (Currently only a few datasets and models are included for testing; more will be added soon.)
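For example (the paths below are placeholders, not defaults):

# Custom locations via flags:
python setup/setup_assets.py --models-dir /data/simply/models --datasets-dir /data/simply/datasets
# Or via environment variables:
export SIMPLY_MODELS=/data/simply/models
export SIMPLY_DATASETS=/data/simply/datasets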

The checkpoints are created with simply.tools.hf_to_orbax, which converts HuggingFace checkpoints to Orbax. We have already converted some Qwen checkpoints, available for download through setup_assets.py; follow the example below to convert more checkpoints.

# Example command for downloading and converting the Qwen3-0.6B checkpoint
name=Qwen3-0.6B
format=Qwen2Format
hf download Qwen/${name} --local-dir ${HF_DIR}${name}
python -m simply.tools.hf_to_orbax \
    --input_path=${HF_DIR}${name}/ \
    --output_path=${HF_DIR}${name}/ORBAX/ \
    --format=${format}
# Remove the safetensors to save space.
rm ${HF_DIR}${name}/*safetensors*
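To sanity-check a converted checkpoint, you can restore it with Orbax directly. This is a minimal sketch using the generic Orbax PyTree API; the exact tree layout produced by hf_to_orbax is an assumption here:

# Minimal sketch: inspect a converted checkpoint with standard Orbax.
import jax
import orbax.checkpoint as ocp

ckpt_path = '/path/to/Qwen3-0.6B/ORBAX'  # placeholder; match --output_path above
params = ocp.PyTreeCheckpointer().restore(ckpt_path)
print(jax.tree_util.tree_map(lambda a: a.shape, params))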

Citation

If you find Simply helpful, please cite the following BibTeX:

@misc{Liang2025Simply,
  author       = {Chen Liang and Da Huang and Chengrun Yang and Xiaomeng Yang and Andrew Li and Xinchen Yan and {Simply Contributors}},
  title        = {{Simply: an experiment to accelerate and automate AI research}},
  year         = {2025},
  howpublished = {GitHub repository},
  url          = {https://github.com/google-deepmind/simply}
}

Contributors list: Alex Zhai, Xingjian Zhang, Jiaxi Tang, Lizhang Chen, Ran Tian

License

Copyright 2025 Google LLC

All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.

This is not an official Google product.
