diff --git a/dimos/models/segmentation/edge_tam.py b/dimos/models/segmentation/edge_tam.py index 88ec707b1c..e706b4c6d8 100644 --- a/dimos/models/segmentation/edge_tam.py +++ b/dimos/models/segmentation/edge_tam.py @@ -53,6 +53,7 @@ class EdgeTAMProcessor(Detector): _frame_count: int _is_tracking: bool _buffer_size: int + _device: str def __init__( self, @@ -62,8 +63,14 @@ def __init__( if not local_config_path.exists(): raise FileNotFoundError(f"EdgeTAM config not found at {local_config_path}") - if not torch.cuda.is_available(): - raise RuntimeError("EdgeTAM requires a CUDA-capable GPU") + if torch.cuda.is_available(): + self._device = "cuda" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + self._device = "mps" + else: + raise RuntimeError( + "EdgeTAM requires a CUDA-capable GPU or an Apple Silicon MPS/Metal backend." + ) cfg = OmegaConf.load(local_config_path) @@ -98,7 +105,7 @@ def __init__( if unexpected_keys: raise RuntimeError("Unexpected keys in checkpoint") - self._predictor = self._predictor.to("cuda") + self._predictor = self._predictor.to(self._device) self._predictor.eval() self._inference_state = None @@ -124,7 +131,7 @@ def _prepare_frame(self, image: Image) -> torch.Tensor: img_np /= img_std img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).float() - img_tensor = img_tensor.cuda() + img_tensor = img_tensor.to(self._device) return img_tensor