fix: support MPS device for EdgeTAM#2269
Conversation
Greptile SummaryThis PR removes the hard-coded CUDA requirement in
Confidence Score: 5/5The change is a minimal, self-contained device-selection fix in one file with no regressions to existing CUDA behaviour. The three changed lines are straightforward: CUDA is tried first (preserving existing behaviour), MPS is the new fallback, and the error path is unchanged. The predictor and frame tensors are moved through the same stored device string, keeping the two call sites in sync. The output path already used .cpu().numpy() and remains device-agnostic. No logic outside edge_tam.py is touched by the fix itself. No files require special attention; the fix is isolated to dimos/models/segmentation/edge_tam.py. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[EdgeTAMProcessor init] --> B{CUDA available?}
B -- Yes --> C[device = cuda]
B -- No --> D{MPS available?}
D -- Yes --> E[device = mps]
D -- No --> F[raise RuntimeError]
C --> G[predictor to device]
E --> G
G --> H[predictor eval]
H --> I[prepare_frame called]
I --> J[img_tensor to device]
J --> K[append to inference_state images]
K --> L[process_results]
L --> M[mask_logits cpu numpy]
Reviews (2): Last reviewed commit: "fix: support MPS device for EdgeTAM" | Re-trigger Greptile |
065fe81 to
feb2726
Compare
Problem
EdgeTAM hardcodes CUDA for both the predictor and input frames, so it cannot run on Apple Silicon machines where PyTorch MPS/Metal is available but CUDA is not.
No linked issue.
Solution
Select CUDA when available, otherwise fall back to the PyTorch MPS backend when present. Store the selected device on the processor and move both the predictor and prepared frame tensors through that device.
Contributor License Agreement