[python-package] Add return type annotations to predict methods in sklearn module #7116
jameslamb
left a comment
Thanks for your interest in LightGBM and for working on this!
There are multiple predict() methods implemented in this file, could you please try fixing the hints on all of them? I think they all are passing through output from Booster.predict() without type coercion so all should be the same.
python-package/lightgbm/sklearn.py:1093: error: Function is missing a return type annotation [no-untyped-def]
python-package/lightgbm/sklearn.py:1586: error: Function is missing a return type annotation [no-untyped-def]
python-package/lightgbm/sklearn.py:1616: error: Function is missing a return type annotation [no-untyped-def]
      validate_features: bool = False,
      **kwargs: Any,
-     ):
+     ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
I think you're right, this should just be identical to the return type hint on Booster.predict()!
I'd been avoiding investigating these because I thought it was possible for pandas outputs to make it in here via the set_output() API (scikit-learn docs) but I think that's something scikit-learn estimators have to opt into and implement themselves, and we haven't done that in LightGBM.
Since you've identified this... let's enforce that relationship by moving the Booster.predict() type hint into a shared type hint variable and re-using it in both places.
- add a _LGBM_PredictReturnType after this:
  LightGBM/python-package/lightgbm/basic.py
  Lines 133 to 140 in 52dbf06
      _LGBM_PredictDataType = Union[
          str,
          Path,
          np.ndarray,
          pd_DataFrame,
          scipy.sparse.spmatrix,
          pa_Table,
      ]
- re-use it in all the relevant places in basic.py and sklearn.py
side note: in 2026 I am finally eliminating that
d88c791 to b99db2b
…learn module

Add return type annotations to LGBMModel.predict, LGBMClassifier.predict, and LGBMClassifier.predict_proba methods.

The predict_proba method requires an isinstance assertion for type narrowing in the binary classification branch, since predict() returns a union type but only returns np.ndarray when pred_contrib=False. This follows the existing error message pattern used in _get_label_from_constructed_dataset and similar helper functions.

Fixes: microsoft#3867
b99db2b to ac6a308
jameslamb
left a comment
Thanks for expanding this to all of the predict() methods. I'll wait to re-review until my suggestions in #7116 (comment) are fully addressed.
Please also do comment on the suggestion threads when you push new changes, explaining what you did.
Summary

Add return type annotations to all predict methods in sklearn module:
- LGBMModel.predict
- LGBMClassifier.predict
- LGBMClassifier.predict_proba

Type narrowing in predict_proba

The predict() method returns Union[np.ndarray, spmatrix, List[spmatrix]]. The else branch in predict_proba does np.vstack((1.0 - result, result)), which only works with np.ndarray.

This else branch only runs when pred_contrib=False, and predict() always returns np.ndarray when pred_contrib=False. Added an assertion with error message to tell mypy this, following the pattern in _get_label_from_constructed_dataset.

Mypy Errors (before fix)

Related Issue

Fixes: #3867

Test Plan

- pre-commit run --files python-package/lightgbm/sklearn.py - all checks passed
- no-untyped-def errors for predict methods
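The type narrowing described for predict_proba can be sketched as below. This is an illustration under assumptions rather than the code in this PR: binary_proba is a hypothetical helper standing in for the else branch, the assertion message is invented, and the final transpose to one row per sample is assumed.

```python
from typing import List, Union

import numpy as np
import scipy.sparse

# Same union as the return type of predict() described above.
PredictReturn = Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]


def binary_proba(result: PredictReturn) -> np.ndarray:
    """Hypothetical helper mirroring the else branch of predict_proba."""
    # Narrow the union for mypy, and fail loudly if the assumption is wrong.
    assert isinstance(result, np.ndarray), (
        f"expected np.ndarray, got {type(result).__name__}"
    )
    # Stack P(class 0) and P(class 1), then transpose to one row per sample.
    return np.vstack((1.0 - result, result)).transpose()


proba = binary_proba(np.array([0.25, 0.75]))
```

After the isinstance assertion, mypy treats result as np.ndarray for the rest of the function, so the arithmetic in np.vstack type-checks without a cast.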