Skip to content

fix: add MPS (Apple Silicon) support for PyTorch backend#415

Open
kime541200 wants to merge 1 commit into
google-research:masterfrom
kime541200:fix/mps-apple-silicon-support
Open

fix: add MPS (Apple Silicon) support for PyTorch backend#415
kime541200 wants to merge 1 commit into
google-research:masterfrom
kime541200:fix/mps-apple-silicon-support

Conversation

@kime541200

Copy link
Copy Markdown

Summary

  • Auto-detect MPS device in TimesFM_2p5_200M_torch_module.__init__: adds an elif torch.backends.mps.is_available() branch so Apple Silicon GPUs are used automatically, without requiring any manual workaround from callers.
  • Fix tensor dtype conversion order in timesfm_2p5_torch.py: convert to float32/bool before moving to device (was .to(device).to(float32)). MPS does not support float64, so the previous order raised Cannot convert a MPS Tensor to float64 when NumPy's default float64 arrays were moved to MPS first.
  • Fix padding array dtype in timesfm_2p5_base.py: explicitly pass dtype=np.float32 when creating padding arrays, preventing NumPy's implicit float64 promotion from propagating into the model.

Test plan

  • Run inference on Apple Silicon Mac (torch.backends.mps.is_available() == True) and confirm no float64 errors
  • Confirm existing CUDA and CPU paths are unaffected (device detection logic unchanged for those branches)

🤖 Generated with Claude Code

@google-cla

google-cla Bot commented May 1, 2026

Copy link
Copy Markdown

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

- Auto-detect MPS device in TimesFM_2p5_200M_torch_module.__init__
- Fix tensor dtype conversion order: cast to float32/bool before moving
  to device, avoiding MPS float64 incompatibility
- Explicitly set dtype=np.float32 on padding arrays to prevent NumPy's
  implicit float64 promotion from propagating to MPS tensors
@kime541200 kime541200 force-pushed the fix/mps-apple-silicon-support branch from 2d66983 to 4c2ca4f Compare June 12, 2026 04:09
@rajatsen91

Copy link
Copy Markdown
Collaborator

LGTM

@rajatsen91 rajatsen91 requested a review from siriuz42 June 12, 2026 15:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants