Skip to content

Commit ac6a308

Browse files
committed
[python-package] Add return type annotations to predict methods in sklearn module
Add return type annotations to LGBMModel.predict, LGBMClassifier.predict, and LGBMClassifier.predict_proba methods. The predict_proba method requires an isinstance assertion for type narrowing in the binary classification branch, since predict() returns a union type but only returns np.ndarray when pred_contrib=False. This follows the existing error message pattern used in _get_label_from_constructed_dataset and similar helper functions. Fixes: #3867
1 parent 52dbf06 commit ac6a308

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

python-package/lightgbm/sklearn.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,7 @@ def predict(
11001100
pred_contrib: bool = False,
11011101
validate_features: bool = False,
11021102
**kwargs: Any,
1103-
):
1103+
) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
11041104
"""Docstring is set after definition, using a template."""
11051105
if not self.__sklearn_is_fitted__():
11061106
raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
@@ -1593,7 +1593,7 @@ def predict(
15931593
pred_contrib: bool = False,
15941594
validate_features: bool = False,
15951595
**kwargs: Any,
1596-
):
1596+
) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
15971597
"""Docstring is inherited from the LGBMModel."""
15981598
result = self.predict_proba(
15991599
X=X,
@@ -1623,7 +1623,7 @@ def predict_proba(
16231623
pred_contrib: bool = False,
16241624
validate_features: bool = False,
16251625
**kwargs: Any,
1626-
):
1626+
) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
16271627
"""Docstring is set after definition, using a template."""
16281628
result = super().predict(
16291629
X=X,
@@ -1645,6 +1645,11 @@ def predict_proba(
16451645
elif self.__is_multiclass or raw_score or pred_leaf or pred_contrib: # type: ignore [operator]
16461646
return result
16471647
else:
1648+
error_msg = (
1649+
"predict() should return np.ndarray when pred_contrib=False. "
1650+
"If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues."
1651+
)
1652+
assert isinstance(result, np.ndarray), error_msg
16481653
return np.vstack((1.0 - result, result)).transpose()
16491654

16501655
predict_proba.__doc__ = _lgbmmodel_doc_predict.format(

0 commit comments

Comments
 (0)