Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bitbots_vision/bitbots_vision/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def add(self, param_name, param_type=None, default=None, description=None, min=N
gen.add(
"yoeo_framework",
str,
description="The neural network framework that should be used ['pytorch', 'openvino', 'onnx', 'tvm']",
description="The neural network framework that should be used ['pytorch', 'openvino', 'onnx', 'tvm', 'rfdetr']",
)

gen.add(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def configure(cls, config: dict) -> None:

@staticmethod
def _verify_framework_parameter(framework: str) -> None:
if framework not in {"openvino", "onnx", "pytorch", "tvm"}:
if framework not in {"openvino", "onnx", "pytorch", "tvm", "rfdetr"}:
logger.error(f"Unknown neural network framework '{framework}'")

@classmethod
Expand Down
76 changes: 76 additions & 0 deletions bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,78 @@ def _compute_new_prediction_for(self, image: np.ndarray) -> tuple[np.ndarray, np
return detections, segmentation


class RFDETREHandlerONNX(YOEOHandlerTemplate):
"""
RF-Detr handler for the ONNX framework.
"""

def __init__(
self,
config: dict,
model_directory: str,
det_class_names: list[str],
det_robot_class_ids: list[int],
seg_class_names: list[str],
):
super().__init__(config, model_directory, det_class_names, det_robot_class_ids, seg_class_names)

logger.debug(f"Entering {self.__class__.__name__} constructor")

onnx_path = YOEOPathGetter.get_rfdetr_onnx_file_path(model_directory)

try:
import onnxruntime
except ImportError as e:
raise ImportError("Could not import onnxruntime. The selected handler requires this package.") from e

logger.debug(f"Loading file...\n\t{onnx_path}")
self._inference_session = onnxruntime.InferenceSession(onnx_path)
self._input_layer = self._inference_session.get_inputs()[0]

print("input layer shape:", self._input_layer.shape)

self._img_preprocessor: utils.IImagePreProcessor = utils.DefaultImagePreProcessor(
tuple(self._input_layer.shape[2:])
)
self._det_postprocessor: utils.IDetectionPostProcessor = utils.DefaultDetectionPostProcessor(
image_preprocessor=self._img_preprocessor,
output_img_size=self._input_layer.shape[2],
conf_thresh=config["yoeo_conf_threshold"],
nms_thresh=config["yoeo_nms_threshold"],
robot_class_ids=self.get_robot_class_ids(),
)
# self._seg_postprocessor: utils.ISegmentationPostProcessor = utils.DefaultSegmentationPostProcessor(
# self._img_preprocessor
# )

logger.debug(f"Leaving {self.__class__.__name__} constructor")

def configure(self, config: dict) -> None:
super().configure(config)
self._det_postprocessor.configure(
image_preprocessor=self._img_preprocessor,
output_img_size=self._input_layer.shape[2],
conf_thresh=config["yoeo_conf_threshold"],
nms_thresh=config["yoeo_nms_threshold"],
robot_class_ids=self.get_robot_class_ids(),
)

@staticmethod
def model_files_exist(model_directory: str) -> bool:
return os.path.exists(YOEOPathGetter.get_rfdetr_onnx_file_path(model_directory))

def _compute_new_prediction_for(self, image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
preproccessed_image = self._img_preprocessor.process(image)

network_input = preproccessed_image.reshape(self._input_layer.shape)
outputs = self._inference_session.run(None, {"InputLayer": network_input.astype(np.float32)})

detections = self._det_postprocessor.process(outputs)
# segmentation = self._seg_postprocessor.process(outputs[1])

return detections, None


class YOEOPathGetter:
"""
PathGetter class for all YOEO handlers. They idea behind this class is to have all path information in one place so
Expand Down Expand Up @@ -575,3 +647,7 @@ def get_tvm_params_file_path(cls, model_directory: str) -> str:
@classmethod
def get_tvm_so_file_path(cls, model_directory: str) -> str:
return cls._assemble_full_path(model_directory, "tvm", "yoeo.so")

@classmethod
def get_rfdetr_onnx_file_path(cls, model_directory) -> str:
return cls._assemble_full_path(model_directory, "onnx", "rfdetr.onnx")
2 changes: 1 addition & 1 deletion bitbots_vision/config/visionparams.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ bitbots_vision:
yoeo_model_path: '2022_10_07_flo_torso21_yoeox'
yoeo_nms_threshold: 0.4 # Non-maximum suppression threshold
yoeo_conf_threshold: 0.5 # YOEO confidence threshold
yoeo_framework: 'tvm' # The neural network framework that should be used ['openvino', 'onnx', 'pytorch', 'tvm']
yoeo_framework: 'rfdetr' # The neural network framework that should be used ['openvino', 'onnx', 'pytorch', 'tvm', 'rfdetr']

ball_candidate_rating_threshold: 0.5 # A threshold for the minimum candidate rating
ball_candidate_max_count: 1 # The maximum number of balls that should be published
Expand Down
2 changes: 1 addition & 1 deletion bitbots_vision/config/visionparams_sim.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ bitbots_vision:
yoeo_model_path: '2022_10_07_flo_torso21_yoeox'
yoeo_nms_threshold: 0.4 # Non-maximum suppression threshold
yoeo_conf_threshold: 0.5 # YOEO confidence threshold
yoeo_framework: 'pytorch' # The neural network framework that should be used ['openvino', 'onnx', 'pytorch', 'tvm']
yoeo_framework: 'pytorch' # The neural network framework that should be used ['openvino', 'onnx', 'pytorch', 'tvm', 'rfdetr']

ball_candidate_rating_threshold: 0.5 # A threshold for the minimum candidate rating
ball_candidate_max_count: 1 # The maximum number of balls that should be published
Expand Down
Loading