[Contrib] Add GaLore (Gradient Low-Rank Projection) Optimizer #1541
base: main
Conversation
rdyro left a comment
Thanks for this contribution!
GaLore sounds like a great addition to contrib, but the two main things we'd have to address are:
- GaLore is a high-level optimizer; it'd be good to let it take the core optimizer as an argument instead of hard-coding Adam
- given GaLore is limited to 2D, it'd be good to implement something similar to `_muon.py` to allow it to treat reshaped non-2D arrays as 2D matrices
optax/contrib/_galore.py (Outdated)
# Project gradient to low-rank subspace
low_rank_grad = _project_gradient_left(grad, new_projector)

# Adam update in low-rank space
The repo for the paper https://github.com/jiaweizzhao/GaLore mentions that the core update can be any optimizer, making GaLore more of a composed optimizer.
Can you take a look at allowing this implementation to take any optax gradient transformation here instead of hard-coding Adam?
Agreed — I’m refactoring GaLore to accept an inner GradientTransformation operating in the projected space, instead of hard-coding Adam.
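For illustration, here is a minimal sketch of such a composed transformation, assuming 2D parameters only; the names (`scale_by_galore_sketch`, `GaLoreSketchState`) and the state layout are illustrative, not the PR's final API:

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp
import optax


class GaLoreSketchState(NamedTuple):
  count: jnp.ndarray           # step counter
  projectors: optax.Params     # one (m, r) projector per 2D parameter
  inner_state: optax.OptState  # state of the wrapped transform, kept in low-rank space


def scale_by_galore_sketch(
    inner: optax.GradientTransformation,
    rank: int = 4,
    update_proj_gap: int = 200,
) -> optax.GradientTransformation:
  """Minimal sketch (2D params only): project grads, run `inner` in rank-`rank` space."""

  def init_fn(params):
    projectors = jax.tree_util.tree_map(
        lambda p: jnp.zeros((p.shape[0], rank), p.dtype), params)
    low_rank_zeros = jax.tree_util.tree_map(
        lambda p: jnp.zeros((rank, p.shape[1]), p.dtype), params)
    return GaLoreSketchState(
        jnp.zeros([], jnp.int32), projectors, inner.init(low_rank_zeros))

  def update_fn(updates, state, params=None):
    del params

    def refresh(g, p_old):
      # Recompute the projector every `update_proj_gap` steps via SVD.
      # (A real implementation would avoid the SVD on non-refresh steps.)
      u, _, _ = jnp.linalg.svd(g.astype(jnp.float32), full_matrices=False)
      p_new = u[:, :rank].astype(g.dtype)
      return jnp.where(state.count % update_proj_gap == 0, p_new, p_old)

    projectors = jax.tree_util.tree_map(refresh, updates, state.projectors)
    # Project gradients into the low-rank subspace: R = Pᵀ G.
    low_rank = jax.tree_util.tree_map(lambda g, p: p.T @ g, updates, projectors)
    # Delegate the actual moment update to the wrapped transformation.
    low_rank_updates, inner_state = inner.update(low_rank, state.inner_state)
    # Project the resulting updates back to the original parameter shapes.
    new_updates = jax.tree_util.tree_map(
        lambda u_, p: p @ u_, low_rank_updates, projectors)
    return new_updates, GaLoreSketchState(state.count + 1, projectors, inner_state)

  return optax.GradientTransformation(init_fn, update_fn)
```

In this shape, `inner` could be `optax.scale_by_adam()` or any other Optax gradient transformation.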
optax/contrib/_galore.py (Outdated)
proj_list = []

for p in leaves:
  if p.ndim == 2:
I think in general we want to allow treating non-2D arrays as 2D matrices for cases like attention projections, which are often stored as (embedding, heads, head_dim) but really are 2D linear maps.
Can you take a look at how Muon approaches this in _muon.py? Would you be interested in adding something like that?
Yes, this makes sense. I’m adapting the approach from _muon.py to allow reshaping non-2D parameters into 2D matrices for projection, while preserving the original parameter structure.
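For reference, a small sketch of the reshape-project-restore pattern being discussed, assuming a single merge point for the axes; the `split` argument and function name are illustrative (the PR uses a `GaLoreDimensionNumbers`-style spec instead):

```python
import math
import jax.numpy as jnp

def project_as_2d(grad: jnp.ndarray, projector: jnp.ndarray, split: int = 1):
  """Fold a non-2D gradient into a matrix, project it, and restore the shape.

  Example: a (embedding, heads, head_dim) tensor with split=1 is viewed as
  (embedding, heads * head_dim) for the low-rank projection.
  """
  original_shape = grad.shape
  rows = math.prod(original_shape[:split])
  matrix = grad.reshape(rows, -1)
  low_rank = projector.T @ matrix  # (r, cols); this is what the inner update sees
  restored = (projector @ low_rank).reshape(original_shape)  # back to the original layout
  return low_rank, restored
```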
optax/contrib/_galore.py (Outdated)
def init_fn(params: base.Params) -> GaLoreState:
  # Use flattening to avoid brittle tuple-unpacking logic
  if not isinstance(rank, int):
the rank validation should be performed above, outside init_fn
Good catch — I’ll move this validation outside init_fn.
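A tiny sketch of the intent, with a hypothetical factory name, validating at construction time so the error surfaces before `init_fn` ever runs:

```python
import optax

def galore_sketch(rank: int = 128) -> optax.GradientTransformation:
  # Hyperparameter validation happens eagerly in the factory, not inside init_fn.
  if not isinstance(rank, int) or rank <= 0:
    raise ValueError(f'rank must be a positive integer, got {rank!r}')
  return optax.identity()  # placeholder body; the real factory builds the GaLore transform
```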
optax/contrib/_galore.py (Outdated)
mu_dtype = utils.canonicalize_dtype(mu_dtype)

def init_fn(params: base.Params) -> GaLoreState:
  # Use flattening to avoid brittle tuple-unpacking logic
nit: this comment is redundant in the context of jax and its pytrees
Agreed — removing this.
optax/contrib/_galore.py (Outdated)
return vh[:rank, :].T.astype(original_dtype)

def _project_gradient_left(
all the _project_.* functions seem redundant as they are only used once, could you delete and inline them?
Sure — I’ll inline these helpers to reduce indirection.
Thanks a lot for the detailed and thoughtful feedback; this is very helpful. I agree with the points raised and I'm working on refactoring the implementation to better align with Optax's design patterns. Planned updates:
- Refactor GaLore to be optimizer-agnostic by composing it with an inner optax.GradientTransformation instead of hard-coding Adam.
- Add support for treating non-2D parameters as 2D (e.g. attention projections), following the approach used in _muon.py.
- Move validation logic (e.g. rank) outside init_fn.
- Remove redundant comments and inline the projection helpers as suggested.

I'll push these changes incrementally so they're easier to review.
Summary of Changes
This update addresses feedback regarding flexibility, tensor support, and code structure. The implementation now supports generic base optimizers (beyond Adam) and high-dimensional tensor projections (adapted from Muon).
1. Configurable Base Optimizer
2. Support for Non-2D Tensors (Dimension Numbers)
3. Code Refactoring & Cleanup
Verification
Thanks, this looks great! I think some tests are missing (applying GaLore to non-2D params), but this can be a follow-up. Feel free to push additional tests if you're up for it.
Thanks for the approval @rdyro! I decided to add the non-2D parameter tests now (along with the full test suite for GaLoreDimensionNumbers and 3D reshaping) so the feature is fully verified in this PR. I've also aligned the code style with the Google style guide (2-space indentation, docstrings) to resolve the linting checks. The PR should be good to merge whenever you're ready. Thanks again for the guidance!
PiperOrigin-RevId: 853480797
This PR implements GaLore (Gradient Low-Rank Projection), a memory-efficient optimizer designed for training large models (LLMs, Diffusion, etc.) on consumer hardware. Based on Zhao et al., 2024, GaLore reduces optimizer state memory by up to 65% while maintaining performance comparable to full-rank AdamW.
This implementation is fully compatible with the Optax ecosystem, including `inject_hyperparams`, gradient transformations, and mixed-precision training.
🧠 Mathematical Formulation & Design Choices
1. The Core Idea
Standard Adam stores two moment states ($M$, $V$) of the same shape as the weight matrix $W \in \mathbb{R}^{m \times n}$. GaLore instead keeps these moments in a rank-$r$ subspace of the gradients, so the dominant optimizer-state cost scales with $r$ rather than with both $m$ and $n$.
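As a rough numerical illustration (figures chosen for this example, not taken from the PR), a left projection replaces the full-rank moments with low-rank moments plus a projector:

$$
\underbrace{2mn}_{\text{full-rank } M,V}\;\longrightarrow\;\underbrace{2rn}_{\text{low-rank } M,V}+\underbrace{mr}_{\text{projector } P},
\qquad m=n=4096,\ r=128:\ \ \approx 33.6\text{M} \to \approx 1.6\text{M} \text{ entries}.
$$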
2. Adaptive Single-Sided Projection (Optimized for Memory)
While the paper discusses "Projection" generally, this implementation uses Adaptive Single-Sided Projection to strictly minimize memory usage.
For a weight gradient $G \in \mathbb{R}^{m \times n}$ and rank $r$:
We can project onto the left subspace ($P \in \mathbb{R}^{m \times r}$) or the right subspace ($Q \in \mathbb{R}^{n \times r}$). The size of the projected moments depends on this choice: the left projection $P^\top G$ gives moments of shape $r \times n$, while the right projection $G Q$ gives moments of shape $m \times r$.
Decision Logic:
We select the strategy that minimizes the total memory footprint.
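A minimal sketch of that decision in isolation (the function name is illustrative; in the PR the choice lives inside `update_single_param`):

```python
def choose_projection_side(m: int, n: int, r: int) -> str:
  """Pick the projection side whose moments plus projector use fewer entries."""
  left_cost = 2 * r * n + m * r    # moments of Pᵀ G are (r, n); projector P is (m, r)
  right_cost = 2 * m * r + n * r   # moments of G Q are (m, r); projector Q is (n, r)
  return 'left' if left_cost <= right_cost else 'right'
```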
This logic is implemented statically in `update_single_param` to ensure optimal memory usage for every specific layer shape (e.g., embedding layers vs. output heads).
3. The Algorithm Step-by-Step
For a chosen projection $P$ (left case example):
- Every `update_proj_gap` steps, recompute $P$ from an SVD of the current gradient $G$.
- Project the gradient into the low-rank space: $R = P^\top G \in \mathbb{R}^{r \times n}$.
- Apply the inner (Adam-style) moment update on $R$.
- Project the resulting update back to $\mathbb{R}^{m \times n}$ and apply the `scale` factor.
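A compact single-matrix sketch of the left-projected step above (variable names are illustrative and bias correction is omitted for brevity):

```python
import jax.numpy as jnp

def galore_left_step(g, mu, nu, p, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, scale=1.0):
  """One left-projected update for a single (m, n) gradient.

  g: full gradient (m, n); p: projector (m, r); mu, nu: low-rank moments (r, n).
  """
  r_g = p.T @ g                           # project the gradient: (r, n)
  mu = b1 * mu + (1 - b1) * r_g           # Adam-style first moment, low-rank
  nu = b2 * nu + (1 - b2) * r_g**2        # Adam-style second moment, low-rank
  low_rank_update = mu / (jnp.sqrt(nu) + eps)
  update = scale * (p @ low_rank_update)  # project back to the full (m, n) space
  return -lr * update, mu, nu
```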
🛠 Implementation Details
File Structure
- `optax/contrib/_galore.py`: Core implementation.
- `optax/contrib/_galore_test.py`: Comprehensive, isolated test suite.
- `optax/contrib/_common_test.py`: Integration into the standard test harness.
Key Technical Limitations & Solutions
A. SVD & `bfloat16` Compatibility
Problem: JAX's LAPACK-based SVD implementation usually does not support `bfloat16` inputs directly, leading to runtime errors.
Solution: We explicitly cast inputs to `float32` before SVD computation and cast the resulting projection matrix back to the original dtype (`bfloat16`). This ensures stability without crashing on TPU/GPU setups using reduced precision.
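A minimal sketch of the dtype round-trip described above (illustrative helper computing a left projector; the PR's own code derives the right projector from `vh`, as seen in the earlier diff):

```python
import jax.numpy as jnp

def compute_left_projector(grad: jnp.ndarray, rank: int) -> jnp.ndarray:
  original_dtype = grad.dtype
  # LAPACK-backed SVD generally rejects bfloat16, so compute in float32...
  u, _, _ = jnp.linalg.svd(grad.astype(jnp.float32), full_matrices=False)
  # ...then cast the projector back so the stored state keeps the original dtype.
  return u[:, :rank].astype(original_dtype)
```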
B. Dtype Preservation in Mixed Precision
Problem: Optimizer arithmetic (like bias correction division) can accidentally promote `bfloat16` states to `float32`, doubling memory usage silently.
Solution: We use explicit casting in `update_fn`:
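The snippet referenced here did not survive extraction; the pattern, roughly, is to do the arithmetic in `float32` and cast the result back to `mu_dtype` (a sketch with illustrative names, not the exact code from the PR):

```python
import jax
import jax.numpy as jnp

def update_moments_preserving_dtype(mu, grads, b1=0.9, mu_dtype=jnp.bfloat16):
  """Accumulate in float32, but store the moments back in mu_dtype."""
  def one(m, g):
    m32 = b1 * m.astype(jnp.float32) + (1 - b1) * g.astype(jnp.float32)
    return m32.astype(mu_dtype)  # explicit cast prevents silent promotion to float32
  return jax.tree_util.tree_map(one, mu, grads)
```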
Explicit tests (`test_mu_dtype_controls_moment_dtype`) verify that updates stay in `bfloat16` if requested.
C. Handling Non-2D Parameters
GaLore is only defined for 2D matrices.
A static check on `grad.ndim != 2` dispatches to the correct update logic, ensuring functional correctness for full model architectures.
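A minimal illustration of the dispatch (the function name is illustrative; the real logic lives in `update_single_param`):

```python
import jax.numpy as jnp

def maybe_project(grad: jnp.ndarray, projector: jnp.ndarray) -> jnp.ndarray:
  """Only 2D gradients take the low-rank path; other leaves pass through."""
  if grad.ndim != 2:           # ndim is static under jit, so a Python branch is fine
    return grad                # biases, layer-norm scales, etc. are left untouched
  return projector.T @ grad    # (r, n) low-rank view of the 2D gradient
```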
🧪 Testing & Verification
1. `_galore_test.py` (New Suite)
We added a dedicated test file covering specific edge cases:
2. `_common_test.py` (Optax Standard)
GaLore passes the standard contrib test suite, including:
- `test_optimizers`: General convergence on Rosenbrock/Parabola.
- `test_gradient_accumulation`: Correctness with `MultiSteps`.
- `test_state_shape_dtype_shard_stability`: Tree structure integrity.
3. Known Skip: `inject_hyperparams`
We explicitly skip `test_optimizers_can_be_wrapped_in_inject_hyperparams` for GaLore.
Reason: SVD is numerically sensitive. When `inject_hyperparams` traces the execution graph, small floating-point associativity differences lead to micro-divergences in the SVD result compared to eager execution. These differences are mathematically valid (subspace rotations) but violate the strict `rtol=1e-5` equality check. This is consistent with other complex optimizers (e.g., `reduce_on_plateau`) that skip this check.
Mitigation: We added `rank` and `update_proj_gap` to the `static_args` allowlist in the test suite to ensure the wrapping mechanism is theoretically sound.
📋 Checklist
- Handles `bfloat16` inputs via `float32` accumulation/computation (Mixed Precision ready).
- Integrated into `optax.contrib` and passes common tests.
Parameters:
- `rank` (int): Dimension of the low-rank subspace (default: 128).
- `update_proj_gap` (int): Steps between SVD re-computations (default: 200).
- `scale` (float): Learning rate scaling factor (default: 1.0).

Fix #1028
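For orientation, a hypothetical end-to-end usage sketch; the public entry point and keyword names below (`optax.contrib.galore`, `learning_rate`) are assumptions, so check the merged module for the actual signature:

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical constructor; defaults mirror the parameters listed above.
opt = optax.contrib.galore(
    learning_rate=1e-3, rank=128, update_proj_gap=200, scale=1.0)

params = {'w': jnp.ones((512, 256)), 'b': jnp.zeros((256,))}
state = opt.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)
```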