ISSUE TITLE:
[Bug] jax[cuda] hardcoded in optional dependencies breaks install on CPU-only machines
ISSUE BODY:
Bug Description
The flax and xreg optional dependency groups in pyproject.toml hardcode jax[cuda]:
[project.optional-dependencies]
flax = [
    "flax",
    "optax",
    "einshape",
    "orbax-checkpoint",
    "jaxtyping",
    "jax[cuda]"  # <-- forces CUDA
]
xreg = [
    "jax[cuda]",  # <-- forces CUDA
    "scikit-learn",
]
This means any user who runs:
pip install timesfm[flax]
# or
pip install timesfm[xreg]
...on a CPU-only machine (local laptop, CI runner, cloud VM without a GPU) ends up
with a failed or broken installation, because jax[cuda] requires NVIDIA CUDA drivers.
Steps to Reproduce
# On any CPU-only machine:
pip install timesfm[flax]
Expected: installs successfully and runs on CPU
Actual: pip fails or JAX fails at runtime with CUDA errors
Error Example
RuntimeError: Unable to initialize backend 'cuda': ...
# or during pip install:
ERROR: Could not find a version that satisfies the requirement jax[cuda]
Environment
- CPU-only machine (no NVIDIA GPU / no CUDA drivers)
- Python 3.10+
- pip install timesfm[flax] or timesfm[xreg]
Proposed Fix
Split each extra that currently pulls in jax[cuda] into a CPU base extra and a
separate -cuda variant. The base extras then use plain jax, which installs the
CPU build and works on machines without a GPU:
[project.optional-dependencies]
# CPU-compatible (works on any machine)
flax = [
    "flax",
    "optax",
    "einshape",
    "orbax-checkpoint",
    "jaxtyping",
    "jax",
]
# GPU only (CUDA 12)
flax-cuda = [
    "flax",
    "optax",
    "einshape",
    "orbax-checkpoint",
    "jaxtyping",
    "jax[cuda12]",
]
# CPU-compatible
xreg = [
    "jax",
    "scikit-learn",
]
# GPU only (CUDA 12)
xreg-cuda = [
    "jax[cuda12]",
    "scikit-learn",
]
Users would then install as:
# CPU
pip install timesfm[flax]
pip install timesfm[xreg]
# GPU
pip install timesfm[flax-cuda]
pip install timesfm[xreg-cuda]
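For install scripts that need to choose between the two variants automatically, one simple heuristic is to look for the nvidia-smi tool that ships with the NVIDIA driver. This is a hypothetical helper sketched under that assumption; pick_timesfm_extra and nvidia_driver_present are illustrative names, and the flax-cuda and xreg-cuda extras exist only in this proposal.

```python
import shutil

# Extra names come from the layout proposed in this issue, not from a released TimesFM.
PROPOSED_BASE_EXTRAS = {"flax", "xreg"}


def pick_timesfm_extra(base: str, has_nvidia_gpu: bool) -> str:
    """Return the pip target, choosing the -cuda variant only when a GPU is usable."""
    if base not in PROPOSED_BASE_EXTRAS:
        raise ValueError(f"unknown extra: {base!r}")
    return f"timesfm[{base}-cuda]" if has_nvidia_gpu else f"timesfm[{base}]"


def nvidia_driver_present() -> bool:
    # Heuristic only: the NVIDIA driver ships nvidia-smi, so its absence from PATH
    # usually means there is no usable CUDA GPU on this machine.
    return shutil.which("nvidia-smi") is not None


if __name__ == "__main__":
    target = pick_timesfm_extra("flax", nvidia_driver_present())
    # Quoting keeps shells such as zsh from treating the brackets as a glob pattern.
    print(f'pip install "{target}"')
```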
Why This Matters
TimesFM is widely used for research and experimentation, and many users run it on
laptops or CPU-based cloud instances. Hardcoding a CUDA build blocks all of them
from using the Flax backend or XReg covariates, and the resulting errors don't
point to this root cause.
Note: jax[cuda] (without version suffix) is also deprecated upstream in
favor of jax[cuda12]. The fix above uses the current recommended name.