[Bug] jax[cuda] hardcoded in optional dependencies breaks install on CPU-only machines #443

@shanujans


Bug Description

The flax and xreg optional dependency groups in pyproject.toml hardcode jax[cuda]:

[project.optional-dependencies]
flax = [
  "flax",
  "optax",
  "einshape",
  "orbax-checkpoint",
  "jaxtyping",
  "jax[cuda]"   # <-- forces CUDA
]
xreg = [
  "jax[cuda]",  # <-- forces CUDA
  "scikit-learn",
]

This means any user who runs:

pip install timesfm[flax]
# or
pip install timesfm[xreg]

...on a CPU-only machine (local laptop, CI runner, cloud VM without GPU) can get a
failed or broken installation, because jax[cuda] pulls in NVIDIA CUDA packages
that such machines do not need and that may not be installable on them.

Steps to Reproduce

# On any CPU-only machine:
pip install timesfm[flax]

Expected: installs successfully and runs on CPU
Actual: pip fails or JAX fails at runtime with CUDA errors

Error Example

RuntimeError: Unable to initialize backend 'cuda': ...
# or during pip install:
ERROR: Could not find a version of jax[cuda] that satisfies the requirement

Environment

  • CPU-only machine (no NVIDIA GPU / no CUDA drivers)
  • Python 3.10+
  • pip install timesfm[flax] or timesfm[xreg]
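
To make an environment report like the one above easier to reproduce, it can help to record which backend JAX actually selects. The following is a minimal sketch, not part of TimesFM: `describe_jax_backend` is a hypothetical helper name, and it relies only on `jax.default_backend()` and `jax.devices()`.

```python
import importlib.util
import platform


def describe_jax_backend() -> str:
    """Return a one-line summary of the JAX backend on this machine."""
    # Hypothetical helper for bug reports; not part of TimesFM or JAX.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    try:
        import jax

        backend = jax.default_backend()  # e.g. "cpu" or "gpu"
        devices = [str(d) for d in jax.devices()]
    except Exception as exc:  # import or backend initialization failed
        return f"jax backend failed to initialize: {exc}"
    return f"backend={backend} devices={devices}"


if __name__ == "__main__":
    print(f"{platform.system()} {platform.machine()}, Python {platform.python_version()}")
    print(describe_jax_backend())
```

Pasting the two printed lines into a report shows at a glance whether the machine ended up on the CPU backend or failed while initializing CUDA.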

Proposed Fix

Add a separate -cuda variant of each extra, and make the base extras use plain
jax (which installs the CPU build and works without a GPU):

[project.optional-dependencies]

# CPU-compatible (works on any machine)
flax = [
  "flax",
  "optax",
  "einshape",
  "orbax-checkpoint",
  "jaxtyping",
  "jax",
]

# GPU only (CUDA 12)
flax-cuda = [
  "flax",
  "optax",
  "einshape",
  "orbax-checkpoint",
  "jaxtyping",
  "jax[cuda12]",
]

# CPU-compatible
xreg = [
  "jax",
  "scikit-learn",
]

# GPU only (CUDA 12)
xreg-cuda = [
  "jax[cuda12]",
  "scikit-learn",
]

Users would then install as:

# CPU
pip install timesfm[flax]
pip install timesfm[xreg]

# GPU
pip install timesfm[flax-cuda]
pip install timesfm[xreg-cuda]
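
For setup scripts that need to pick between the two variants, a rough sketch follows. It assumes the -cuda extras proposed in this issue exist, uses the presence of `nvidia-smi` on PATH as a heuristic for an NVIDIA GPU, and `suggested_install` is a hypothetical helper name.

```python
import shutil


def suggested_install(extra: str) -> str:
    """Return a pip command for `extra`, using the -cuda variant if a GPU looks present."""
    # Heuristic assumption: an NVIDIA driver install puts `nvidia-smi` on PATH.
    has_nvidia = shutil.which("nvidia-smi") is not None
    variant = f"{extra}-cuda" if has_nvidia else extra
    return f'pip install "timesfm[{variant}]"'


if __name__ == "__main__":
    print(suggested_install("flax"))
```

The requirement is quoted in the generated command because some shells, zsh for example, treat unquoted square brackets as glob patterns.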

Why This Matters

TimesFM is widely used for research and experimentation, and many users run it
on laptops or CPU-based cloud instances. Forcing a CUDA install blocks them from
using the Flax backend or XReg covariates, and the resulting error messages do
not point to this root cause.

Note: jax[cuda] (without version suffix) is also deprecated upstream in
favor of jax[cuda12]. The fix above uses the current recommended name.
