diff --git a/bitbots_vision/bitbots_vision/params.py b/bitbots_vision/bitbots_vision/params.py
index 3fc386011..5e9943125 100755
--- a/bitbots_vision/bitbots_vision/params.py
+++ b/bitbots_vision/bitbots_vision/params.py
@@ -94,7 +94,7 @@ def add(self, param_name, param_type=None, default=None, description=None, min=N
 gen.add(
     "yoeo_framework",
     str,
-    description="The neural network framework that should be used ['pytorch', 'openvino', 'onnx', 'tvm']",
+    description="The neural network framework that should be used ['pytorch', 'openvino', 'onnx', 'tvm', 'rfdetr']",
 )
 
 gen.add(
diff --git a/bitbots_vision/bitbots_vision/vision_modules/yoeo/object_manager.py b/bitbots_vision/bitbots_vision/vision_modules/yoeo/object_manager.py
index 3aadfb974..020b0ad62 100644
--- a/bitbots_vision/bitbots_vision/vision_modules/yoeo/object_manager.py
+++ b/bitbots_vision/bitbots_vision/vision_modules/yoeo/object_manager.py
@@ -90,7 +90,7 @@ def configure(cls, config: dict) -> None:
 
     @staticmethod
     def _verify_framework_parameter(framework: str) -> None:
-        if framework not in {"openvino", "onnx", "pytorch", "tvm"}:
+        if framework not in {"openvino", "onnx", "pytorch", "tvm", "rfdetr"}:
             logger.error(f"Unknown neural network framework '{framework}'")
 
     @classmethod
diff --git a/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py b/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py
index 9409dc3ea..3e074b61d 100644
--- a/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py
+++ b/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py
@@ -526,6 +526,92 @@ def _compute_new_prediction_for(self, image: np.ndarray) -> tuple[np.ndarray, np
         return detections, segmentation
 
 
+class RFDETREHandlerONNX(YOEOHandlerTemplate):
+    """
+    RF-DETR handler for the ONNX framework.
+
+    Runs the model with onnxruntime and yields detections only: no segmentation post-processor is set up, so the
+    segmentation slot of every prediction is None.
+    """
+
+    def __init__(
+        self,
+        config: dict,
+        model_directory: str,
+        det_class_names: list[str],
+        det_robot_class_ids: list[int],
+        seg_class_names: list[str],
+    ):
+        super().__init__(config, model_directory, det_class_names, det_robot_class_ids, seg_class_names)
+
+        logger.debug(f"Entering {self.__class__.__name__} constructor")
+
+        onnx_path = YOEOPathGetter.get_rfdetr_onnx_file_path(model_directory)
+
+        try:
+            import onnxruntime
+        except ImportError as e:
+            raise ImportError("Could not import onnxruntime. The selected handler requires this package.") from e
+
+        logger.debug(f"Loading file...\n\t{onnx_path}")
+        self._inference_session = onnxruntime.InferenceSession(onnx_path)
+        self._input_layer = self._inference_session.get_inputs()[0]
+
+        logger.debug(f"Input layer shape: {self._input_layer.shape}")
+
+        # NOTE(review): shape[2:] / shape[2] are used as the network input size, which presumes a 4-D input whose
+        # trailing dimensions are spatial - confirm against the exported model.
+        self._img_preprocessor: utils.IImagePreProcessor = utils.DefaultImagePreProcessor(
+            tuple(self._input_layer.shape[2:])
+        )
+        self._det_postprocessor: utils.IDetectionPostProcessor = utils.DefaultDetectionPostProcessor(
+            image_preprocessor=self._img_preprocessor,
+            output_img_size=self._input_layer.shape[2],
+            conf_thresh=config["yoeo_conf_threshold"],
+            nms_thresh=config["yoeo_nms_threshold"],
+            robot_class_ids=self.get_robot_class_ids(),
+        )
+        # TODO: segmentation is not supported yet; set up a utils.DefaultSegmentationPostProcessor here once it is.
+
+        logger.debug(f"Leaving {self.__class__.__name__} constructor")
+
+    def configure(self, config: dict) -> None:
+        """
+        Applies the given config to the base handler and to the detection post-processor.
+        """
+        super().configure(config)
+        self._det_postprocessor.configure(
+            image_preprocessor=self._img_preprocessor,
+            output_img_size=self._input_layer.shape[2],
+            conf_thresh=config["yoeo_conf_threshold"],
+            nms_thresh=config["yoeo_nms_threshold"],
+            robot_class_ids=self.get_robot_class_ids(),
+        )
+
+    @staticmethod
+    def model_files_exist(model_directory: str) -> bool:
+        """
+        Returns True if the RF-DETR ONNX file exists for the given model directory.
+        """
+        return os.path.exists(YOEOPathGetter.get_rfdetr_onnx_file_path(model_directory))
+
+    def _compute_new_prediction_for(self, image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Runs inference on the given image and returns the tuple (detections, segmentation).
+
+        NOTE: segmentation is always None for this handler, which deviates from the annotated return type.
+        """
+        preprocessed_image = self._img_preprocessor.process(image)
+
+        network_input = preprocessed_image.reshape(self._input_layer.shape)
+        # Feed the input under the name the loaded model declares rather than a hard-coded layer name.
+        outputs = self._inference_session.run(None, {self._input_layer.name: network_input.astype(np.float32)})
+
+        detections = self._det_postprocessor.process(outputs)
+
+        return detections, None
+
+
 class YOEOPathGetter:
     """
     PathGetter class for all YOEO handlers. They idea behind this class is to have all path information in one place so
@@ -575,3 +661,7 @@ def get_tvm_params_file_path(cls, model_directory: str) -> str:
     @classmethod
     def get_tvm_so_file_path(cls, model_directory: str) -> str:
         return cls._assemble_full_path(model_directory, "tvm", "yoeo.so")
+
+    @classmethod
+    def get_rfdetr_onnx_file_path(cls, model_directory: str) -> str:
+        return cls._assemble_full_path(model_directory, "onnx", "rfdetr.onnx")
diff --git a/bitbots_vision/config/visionparams.yaml b/bitbots_vision/config/visionparams.yaml
index bca73cb0a..82f6427df 100644
--- a/bitbots_vision/config/visionparams.yaml
+++ b/bitbots_vision/config/visionparams.yaml
@@ -20,7 +20,7 @@ bitbots_vision:
     yoeo_model_path: '2022_10_07_flo_torso21_yoeox'
     yoeo_nms_threshold: 0.4 # Non-maximum suppression threshold
     yoeo_conf_threshold: 0.5 # YOEO confidence threshold
-    yoeo_framework: 'tvm' # The neural network framework that should be used ['openvino', 'onnx', 'pytorch', 'tvm']
+    yoeo_framework: 'rfdetr' # The neural network framework that should be used ['openvino', 'onnx', 'pytorch', 'tvm', 'rfdetr']
 
     ball_candidate_rating_threshold: 0.5 # A threshold for the minimum candidate rating
     ball_candidate_max_count: 1 # The maximum number of balls that should be published
diff --git a/bitbots_vision/config/visionparams_sim.yaml b/bitbots_vision/config/visionparams_sim.yaml
index 016b4ec89..c3bdb9271 100644
--- a/bitbots_vision/config/visionparams_sim.yaml
+++ b/bitbots_vision/config/visionparams_sim.yaml
@@ -20,7 +20,7 @@ bitbots_vision:
     yoeo_model_path: '2022_10_07_flo_torso21_yoeox'
     yoeo_nms_threshold: 0.4 # Non-maximum suppression threshold
     yoeo_conf_threshold: 0.5 # YOEO confidence threshold
-    yoeo_framework: 'pytorch' # The neural network framework that should be used ['openvino', 'onnx', 'pytorch', 'tvm']
+    yoeo_framework: 'pytorch' # The neural network framework that should be used ['openvino', 'onnx', 'pytorch', 'tvm', 'rfdetr']
 
     ball_candidate_rating_threshold: 0.5 # A threshold for the minimum candidate rating
     ball_candidate_max_count: 1 # The maximum number of balls that should be published